CMS 3D CMS Logo

All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
Predictor.cc
Go to the documentation of this file.
1 /*
2  * Predictor.cc
3  *
4  * Created on: Jul 19, 2018
5  * Author: hqu
6  */
7 
9 
10 #include <cassert>
12 
13 namespace mxnet {
14 
15  namespace cpp {
16 
18 
19  Block::Block(const std::string& symbol_file, const std::string& param_file) {
20  // load the symbol
21  sym_ = Symbol::Load(symbol_file);
22  // load the parameters
23  load_parameters(param_file);
24  }
25 
27 
28  void Block::load_parameters(const std::string& param_file) {
29  std::map<std::string, NDArray> paramters;
30  NDArray::Load(param_file, nullptr, &paramters);
31  for (const auto& k : paramters) {
32  if (k.first.substr(0, 4) == "aux:") {
33  auto name = k.first.substr(4, k.first.size() - 4);
34  aux_map_[name] = k.second;
35  }
36  if (k.first.substr(0, 4) == "arg:") {
37  auto name = k.first.substr(4, k.first.size() - 4);
38  arg_map_[name] = k.second;
39  }
40  }
41  }
42 
44  const Context Predictor::context_ = Context(DeviceType::kCPU, 0);
45 
47 
49  : sym_(block.symbol()), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {}
50 
51  Predictor::Predictor(const Block& block, const std::string& output_node)
52  : sym_(block.symbol(output_node)), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {}
53 
55 
56  void Predictor::set_input_shapes(const std::vector<std::string>& input_names,
57  const std::vector<std::vector<mx_uint> >& input_shapes) {
58  assert(input_names.size() == input_shapes.size());
60  // init the input NDArrays and add them to the arg_map
61  for (unsigned i = 0; i < input_names_.size(); ++i) {
62  const auto& name = input_names_[i];
63  arg_map_.emplace(name, NDArray(input_shapes[i], context_, false));
64  }
65  }
66 
67  const std::vector<float>& Predictor::predict(const std::vector<std::vector<mx_float> >& input_data) {
68  assert(input_names_.size() == input_data.size());
69 
70  try {
71  // create the executor (if not done yet)
72  if (!exec_) {
73  bind_executor();
74  }
75  assert(exec_);
76  // set the inputs
77  for (unsigned i = 0; i < input_names_.size(); ++i) {
78  const auto& name = input_names_[i];
79  arg_map_[name].SyncCopyFromCPU(input_data[i]);
80  }
81  // run forward
82  exec_->Forward(false);
83  // copy the output to pred_
84  exec_->outputs[0].SyncCopyToCPU(&pred_);
85  return pred_;
86  } catch (const dmlc::Error& e) {
87  throw cms::Exception("RuntimeError") << e.what() << MXGetLastError();
88  }
89  }
90 
92  // acquire lock
93  std::lock_guard<std::mutex> lock(mutex_);
94 
95  // infer shapes
96  const auto arg_name_list = sym_.ListArguments();
97  std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
98  std::map<std::string, std::vector<mx_uint> > arg_shapes;
99 
100  for (const auto& arg_name : arg_name_list) {
101  auto iter = arg_map_.find(arg_name);
102  if (iter != arg_map_.end()) {
103  arg_shapes[arg_name] = iter->second.GetShape();
104  }
105  }
106  sym_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);
107 
108  // init argument arrays
109  std::vector<NDArray> arg_arrays;
110  for (size_t i = 0; i < in_shapes.size(); ++i) {
111  const auto& shape = in_shapes[i];
112  const auto& arg_name = arg_name_list[i];
113  auto iter_arg = arg_map_.find(arg_name);
114  if (iter_arg != arg_map_.end()) {
115  arg_arrays.push_back(iter_arg->second);
116  } else {
117  arg_arrays.push_back(NDArray(shape, context_, false));
118  }
119  }
120  std::vector<NDArray> grad_arrays(arg_arrays.size());
121  std::vector<OpReqType> grad_reqs(arg_arrays.size(), kNullOp);
122 
123  // init auxiliary array
124  std::vector<NDArray> aux_arrays;
125  const auto aux_name_list = sym_.ListAuxiliaryStates();
126  for (size_t i = 0; i < aux_shapes.size(); ++i) {
127  const auto& shape = aux_shapes[i];
128  const auto& aux_name = aux_name_list[i];
129  auto iter_aux = aux_map_.find(aux_name);
130  if (iter_aux != aux_map_.end()) {
131  aux_arrays.push_back(iter_aux->second);
132  } else {
133  aux_arrays.push_back(NDArray(shape, context_, false));
134  }
135  }
136 
137  // bind executor
138  exec_.reset(new Executor(sym_, context_, arg_arrays, grad_arrays, grad_reqs, aux_arrays));
139  }
140 
141  } // namespace cpp
142 
143 } /* namespace mxnet */
mxnet
Definition: Predictor.h:18
mxnet::cpp::Predictor::Predictor
Predictor()
Definition: Predictor.cc:46
mps_fire.i
i
Definition: mps_fire.py:355
Predictor.h
mxnet::cpp::Predictor::mutex_
static std::mutex mutex_
Definition: Predictor.h:65
mxnet::cpp::Block::Block
Block()
Definition: Predictor.cc:17
mxnet::cpp::Predictor::~Predictor
virtual ~Predictor()
Definition: Predictor.cc:54
cms::cuda::assert
assert(be >=bs)
Context
mxnet::cpp::Block
Definition: Predictor.h:26
mxnet::cpp::Predictor::bind_executor
void bind_executor()
Definition: Predictor.cc:91
mxnet::cpp::Predictor::context_
static const Context context_
Definition: Predictor.h:70
mxnet::cpp::Predictor::set_input_shapes
void set_input_shapes(const std::vector< std::string > &input_names, const std::vector< std::vector< mx_uint >> &input_shapes)
Definition: Predictor.cc:56
mxnet::cpp::Predictor::exec_
std::unique_ptr< Executor > exec_
Definition: Predictor.h:72
dqmdumpme.k
k
Definition: dqmdumpme.py:60
mutex
static boost::mutex mutex
Definition: Proxy.cc:9
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
mxnet::cpp::Block::load_parameters
void load_parameters(const std::string &param_file)
Definition: Predictor.cc:28
CommonMethods.lock
def lock()
Definition: CommonMethods.py:82
mxnet::cpp::Predictor::input_names_
std::vector< std::string > input_names_
Definition: Predictor.h:82
mxnet::cpp::Block::~Block
virtual ~Block()
Definition: Predictor.cc:26
leef::Error
edm::ErrorSummaryEntry Error
Definition: LogErrorEventFilter.cc:29
groupFilesInBlocks.block
block
Definition: groupFilesInBlocks.py:150
mxnet::cpp::Predictor::predict
const std::vector< float > & predict(const std::vector< std::vector< mx_float >> &input_data)
Definition: Predictor.cc:67
mxnet::cpp::Block::arg_map_
std::map< std::string, NDArray > arg_map_
Definition: Predictor.h:43
mxnet::cpp::Block::sym_
Symbol sym_
Definition: Predictor.h:41
mxnet::cpp::Block::aux_map_
std::map< std::string, NDArray > aux_map_
Definition: Predictor.h:45
Exception
Definition: hltDiff.cc:246
Skims_PA_cff.name
name
Definition: Skims_PA_cff.py:17
Exception.h
mxnet::cpp::Predictor::pred_
std::vector< float > pred_
Definition: Predictor.h:80
mxnet::cpp::Predictor::arg_map_
std::map< std::string, NDArray > arg_map_
Definition: Predictor.h:76
pfParticleNetPreprocessParams_cfi.input_names
input_names
Definition: pfParticleNetPreprocessParams_cfi.py:4
mxnet::cpp::Predictor::aux_map_
std::map< std::string, NDArray > aux_map_
Definition: Predictor.h:78
MillePedeFileConverter_cfg.e
e
Definition: MillePedeFileConverter_cfg.py:37
mxnet::cpp::Predictor::sym_
Symbol sym_
Definition: Predictor.h:74