CMS 3D CMS Logo

Predictor.cc
Go to the documentation of this file.
1 /*
2  * MXNetCppPredictor.cc
3  *
4  * Created on: Jul 19, 2018
5  * Author: hqu
6  */
7 
9 
10 #include <cassert>
11 #include <memory>
12 
14 
15 namespace mxnet {
16 
17  namespace cpp {
18 
20 
21  Block::Block(const std::string& symbol_file, const std::string& param_file) {
22  // load the symbol
23  sym_ = Symbol::Load(symbol_file);
24  // load the parameters
25  load_parameters(param_file);
26  }
27 
29 
30  void Block::load_parameters(const std::string& param_file) {
31  std::map<std::string, NDArray> paramters;
32  NDArray::Load(param_file, nullptr, &paramters);
33  for (const auto& k : paramters) {
34  if (k.first.substr(0, 4) == "aux:") {
35  auto name = k.first.substr(4, k.first.size() - 4);
36  aux_map_[name] = k.second;
37  }
38  if (k.first.substr(0, 4) == "arg:") {
39  auto name = k.first.substr(4, k.first.size() - 4);
40  arg_map_[name] = k.second;
41  }
42  }
43  }
44 
46  const Context Predictor::context_ = Context(DeviceType::kCPU, 0);
47 
49 
// Member initializers: sym_ takes the symbol returned by block.symbol(), and
// arg_map_ / aux_map_ are copies of the block's argument and auxiliary maps.
51  : sym_(block.symbol()), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {}
52 
53  Predictor::Predictor(const Block& block, const std::string& output_node)
54  : sym_(block.symbol(output_node)), arg_map_(block.arg_map()), aux_map_(block.aux_map()) {}
55 
57 
58  void Predictor::set_input_shapes(const std::vector<std::string>& input_names,
59  const std::vector<std::vector<mx_uint> >& input_shapes) {
60  assert(input_names.size() == input_shapes.size());
62  // init the input NDArrays and add them to the arg_map
63  for (unsigned i = 0; i < input_names_.size(); ++i) {
64  const auto& name = input_names_[i];
65  arg_map_.emplace(name, NDArray(input_shapes[i], context_, false));
66  }
67  }
68 
69  const std::vector<float>& Predictor::predict(const std::vector<std::vector<mx_float> >& input_data) {
70  assert(input_names_.size() == input_data.size());
71 
72  try {
73  // create the executor (if not done yet)
74  if (!exec_) {
75  bind_executor();
76  }
77  assert(exec_);
78  // set the inputs
79  for (unsigned i = 0; i < input_names_.size(); ++i) {
80  const auto& name = input_names_[i];
81  arg_map_[name].SyncCopyFromCPU(input_data[i]);
82  }
83  // run forward
84  exec_->Forward(false);
85  // copy the output to pred_
86  exec_->outputs[0].SyncCopyToCPU(&pred_);
87  return pred_;
88  } catch (const dmlc::Error& e) {
89  throw cms::Exception("RuntimeError") << e.what() << MXGetLastError();
90  }
91  }
92 
// Bind exec_: collect the known argument shapes, let the symbol infer the rest,
// gather argument/auxiliary NDArrays (reusing those already in arg_map_/aux_map_),
// and create an Executor with no gradient computation (all grad_reqs = kNullOp).
// Any previously bound executor held in exec_ is replaced.
94  // acquire lock
// mutex_ is static, so binding is serialized across all Predictor instances.
// NOTE(review): presumably because MXNet binding is not thread-safe; confirm.
95  std::lock_guard<std::mutex> lock(mutex_);
96 
97  // infer shapes
98  const auto arg_name_list = sym_.ListArguments();
99  std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
100  std::map<std::string, std::vector<mx_uint> > arg_shapes;
101 
// Only arguments already present in arg_map_ (loaded parameters and inputs) supply a shape.
102  for (const auto& arg_name : arg_name_list) {
103  auto iter = arg_map_.find(arg_name);
104  if (iter != arg_map_.end()) {
105  arg_shapes[arg_name] = iter->second.GetShape();
106  }
107  }
108  sym_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);
109 
110  // init argument arrays
// NOTE(review): in_shapes[i] is paired with arg_name_list[i]; this assumes
// InferShape returns shapes in ListArguments() order. Confirm against MXNet.
111  std::vector<NDArray> arg_arrays;
112  for (size_t i = 0; i < in_shapes.size(); ++i) {
113  const auto& shape = in_shapes[i];
114  const auto& arg_name = arg_name_list[i];
115  auto iter_arg = arg_map_.find(arg_name);
116  if (iter_arg != arg_map_.end()) {
117  arg_arrays.push_back(iter_arg->second);
118  } else {
// Argument not in arg_map_: create a new NDArray with the inferred shape on context_.
119  arg_arrays.push_back(NDArray(shape, context_, false));
120  }
121  }
// Inference only: default-constructed gradient arrays, with kNullOp for every argument.
122  std::vector<NDArray> grad_arrays(arg_arrays.size());
123  std::vector<OpReqType> grad_reqs(arg_arrays.size(), kNullOp);
124 
125  // init auxiliary array
// NOTE(review): aux_shapes[i] is paired with aux_name_list[i]; this assumes
// matching order from ListAuxiliaryStates(). Confirm against MXNet.
126  std::vector<NDArray> aux_arrays;
127  const auto aux_name_list = sym_.ListAuxiliaryStates();
128  for (size_t i = 0; i < aux_shapes.size(); ++i) {
129  const auto& shape = aux_shapes[i];
130  const auto& aux_name = aux_name_list[i];
131  auto iter_aux = aux_map_.find(aux_name);
132  if (iter_aux != aux_map_.end()) {
133  aux_arrays.push_back(iter_aux->second);
134  } else {
135  aux_arrays.push_back(NDArray(shape, context_, false));
136  }
137  }
138 
139  // bind executor
140  exec_ = std::make_unique<Executor>(sym_, context_, arg_arrays, grad_arrays, grad_reqs, aux_arrays);
141  }
142 
143  } // namespace cpp
144 
145 } /* namespace mxnet */
edm::ErrorSummaryEntry Error
virtual ~Block()
Definition: Predictor.cc:28
std::map< std::string, NDArray > aux_map_
Definition: Predictor.h:45
static std::mutex mutex
Definition: Proxy.cc:8
std::map< std::string, NDArray > arg_map_
Definition: Predictor.h:43
void load_parameters(const std::string &param_file)
Definition: Predictor.cc:30
virtual ~Predictor()
Definition: Predictor.cc:56
std::map< std::string, NDArray > aux_map_
Definition: Predictor.h:78
assert(be >=bs)
std::vector< float > pred_
Definition: Predictor.h:80
static const Context context_
Definition: Predictor.h:70
std::map< std::string, NDArray > arg_map_
Definition: Predictor.h:76
std::unique_ptr< Executor > exec_
Definition: Predictor.h:72
const std::vector< float > & predict(const std::vector< std::vector< mx_float >> &input_data)
Definition: Predictor.cc:69
void set_input_shapes(const std::vector< std::string > &input_names, const std::vector< std::vector< mx_uint >> &input_shapes)
Definition: Predictor.cc:58
std::vector< std::string > input_names_
Definition: Predictor.h:82
static std::mutex mutex_
Definition: Predictor.h:65