CMS 3D CMS Logo

Predictor.cc
Go to the documentation of this file.
1 /*
2  * MXNetCppPredictor.cc
3  *
4  * Created on: Jul 19, 2018
5  * Author: hqu
6  */
7 
9 
10 #include <cassert>
12 
13 
14 namespace mxnet {
15 
16 namespace cpp {
17 
19 }
20 
21 Block::Block(const std::string& symbol_file, const std::string& param_file) {
22  // load the symbol
23  sym_ = Symbol::Load(symbol_file);
24  // load the parameters
25  load_parameters(param_file);
26 }
27 
29 }
30 
31 void Block::load_parameters(const std::string& param_file) {
32  std::map<std::string, NDArray> paramters;
33  NDArray::Load(param_file, nullptr, &paramters);
34  for (const auto &k : paramters) {
35  if (k.first.substr(0, 4) == "aux:") {
36  auto name = k.first.substr(4, k.first.size() - 4);
37  aux_map_[name] = k.second;
38  }
39  if (k.first.substr(0, 4) == "arg:") {
40  auto name = k.first.substr(4, k.first.size() - 4);
41  arg_map_[name] = k.second;
42  }
43  }
44 }
45 
47 const Context Predictor::context_ = Context(DeviceType::kCPU, 0);
48 
50 }
51 
53 : sym_(block.symbol()), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {
54 }
55 
56 Predictor::Predictor(const Block &block, const std::string &output_node)
57 : sym_(block.symbol(output_node)), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {
58 }
59 
61 }
62 
63 void Predictor::set_input_shapes(const std::vector<std::string>& input_names, const std::vector<std::vector<mx_uint> >& input_shapes) {
64  assert(input_names.size() == input_shapes.size());
66  // init the input NDArrays and add them to the arg_map
67  for (unsigned i=0; i<input_names_.size(); ++i){
68  const auto& name = input_names_[i];
69  arg_map_.emplace(name, NDArray(input_shapes[i], context_, false));
70  }
71 }
72 
73 const std::vector<float>& Predictor::predict(const std::vector<std::vector<mx_float> >& input_data) {
74  assert(input_names_.size() == input_data.size());
75 
76  try {
77  // create the executor (if not done yet)
78  if (!exec_) { bind_executor(); }
79  assert(exec_);
80  // set the inputs
81  for (unsigned i=0; i<input_names_.size(); ++i){
82  const auto& name = input_names_[i];
83  arg_map_[name].SyncCopyFromCPU(input_data[i]);
84  }
85  // run forward
86  exec_->Forward(false);
87  // copy the output to pred_
88  exec_->outputs[0].SyncCopyToCPU(&pred_);
89  return pred_;
90  }catch(const dmlc::Error &e){
91  throw cms::Exception("RuntimeError") << e.what() << MXGetLastError();
92  }
93 }
94 
// Body of Predictor::bind_executor (its signature line was dropped by the scrape).
// Infers all argument/auxiliary shapes from the arrays already in arg_map_ (copied
// from the Block and extended by set_input_shapes), reuses those arrays where they
// exist, allocates new ones on context_ for the rest, and binds exec_ with no
// gradient requests.
96  // acquire lock
// mutex_ is a static std::mutex, so binding is serialized across all Predictor
// instances. NOTE(review): presumably because these MXNet calls must not run
// concurrently — confirm.
97  std::lock_guard<std::mutex> lock(mutex_);
98 
99  // infer shapes
100  const auto arg_name_list = sym_.ListArguments();
101  std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
102  std::map<std::string, std::vector<mx_uint> > arg_shapes;
103 
// Only arguments that already have an array contribute a known shape to inference.
104  for (const auto &arg_name : arg_name_list) {
105  auto iter = arg_map_.find(arg_name);
106  if (iter != arg_map_.end()) {
107  arg_shapes[arg_name] = iter->second.GetShape();
108  }
109  }
110  sym_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);
111 
112  // init argument arrays
// The loop assumes in_shapes[i] belongs to arg_name_list[i] (ListArguments order).
113  std::vector<NDArray> arg_arrays;
114  for (size_t i = 0; i < in_shapes.size(); ++i) {
115  const auto &shape = in_shapes[i];
116  const auto &arg_name = arg_name_list[i];
117  auto iter_arg = arg_map_.find(arg_name);
118  if (iter_arg != arg_map_.end()) {
119  arg_arrays.push_back(iter_arg->second);
120  } else {
// No array supplied for this argument: allocate one of the inferred shape on context_.
121  arg_arrays.push_back(NDArray(shape, context_, false));
122  }
123  }
// No gradients: empty gradient arrays and kNullOp for every argument.
124  std::vector<NDArray> grad_arrays(arg_arrays.size());
125  std::vector<OpReqType> grad_reqs(arg_arrays.size(), kNullOp);
126 
127  // init auxiliary array
// Same scheme as above; assumes aux_shapes[i] belongs to aux_name_list[i].
128  std::vector<NDArray> aux_arrays;
129  const auto aux_name_list = sym_.ListAuxiliaryStates();
130  for (size_t i = 0; i < aux_shapes.size(); ++i) {
131  const auto &shape = aux_shapes[i];
132  const auto &aux_name = aux_name_list[i];
133  auto iter_aux = aux_map_.find(aux_name);
134  if (iter_aux != aux_map_.end()) {
135  aux_arrays.push_back(iter_aux->second);
136  } else {
137  aux_arrays.push_back(NDArray(shape, context_, false));
138  }
139  }
140 
141  // bind executor
// exec_ takes ownership of the new Executor; predict() checks it to decide whether to bind.
142  exec_.reset(new Executor(sym_, context_, arg_arrays, grad_arrays, grad_reqs, aux_arrays));
143 }
144 
145 }
146 
147 } /* namespace mxnet */
148 
const std::map< std::string, NDArray > & aux_map() const
Definition: Predictor.h:35
virtual ~Block()
Definition: Predictor.cc:28
std::map< std::string, NDArray > aux_map_
Definition: Predictor.h:45
static boost::mutex mutex
Definition: LHEProxy.cc:11
std::map< std::string, NDArray > arg_map_
Definition: Predictor.h:43
void load_parameters(const std::string &param_file)
Definition: Predictor.cc:31
virtual ~Predictor()
Definition: Predictor.cc:60
std::map< std::string, NDArray > aux_map_
Definition: Predictor.h:77
std::vector< float > pred_
Definition: Predictor.h:79
const std::map< std::string, NDArray > & arg_map() const
Definition: Predictor.h:34
int k[5][pyjets_maxn]
const Symbol & symbol() const
Definition: Predictor.h:32
static const Context context_
Definition: Predictor.h:69
std::map< std::string, NDArray > arg_map_
Definition: Predictor.h:75
std::unique_ptr< Executor > exec_
Definition: Predictor.h:71
const std::vector< float > & predict(const std::vector< std::vector< mx_float >> &input_data)
Definition: Predictor.cc:73
void set_input_shapes(const std::vector< std::string > &input_names, const std::vector< std::vector< mx_uint >> &input_shapes)
Definition: Predictor.cc:63
std::vector< std::string > input_names_
Definition: Predictor.h:81
static std::mutex mutex_
Definition: Predictor.h:64