CMS 3D CMS Logo

Predictor.h
Go to the documentation of this file.
1 /*
2  * MXNetCppPredictor.h
3  *
4  * Created on: Jul 19, 2018
5  * Author: hqu
6  */
7 
8 #ifndef PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_
9 #define PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_
10 
11 #include <map>
12 #include <vector>
13 #include <memory>
14 #include <mutex>
15 
16 #include "mxnet-cpp/MxNetCpp.h"
17 
18 namespace mxnet {
19 
20 namespace cpp {
21 
22 // note: Most of the objects in mxnet::cpp are effective just shared_ptr's
23 
24 // Simple class to hold MXNet model (symbol + params)
25 // designed to be sharable by multiple threads
26 class Block {
27 public:
28  Block();
29  Block(const std::string &symbol_file, const std::string &param_file);
30  virtual ~Block();
31 
32  const Symbol& symbol() const { return sym_; }
33  Symbol symbol(const std::string &output_node) const { return sym_.GetInternals()[output_node]; }
34  const std::map<std::string, NDArray>& arg_map() const { return arg_map_; }
35  const std::map<std::string, NDArray>& aux_map() const { return aux_map_; }
36 
37 private:
38  void load_parameters(const std::string& param_file);
39 
40  // symbol
41  Symbol sym_;
42  // argument arrays
43  std::map<std::string, NDArray> arg_map_;
44  // auxiliary arrays
45  std::map<std::string, NDArray> aux_map_;
46 };
47 
48 // Simple helper class to run prediction
49 // this cannot be shared between threads
50 class Predictor {
51 public:
52  Predictor();
53  Predictor(const Block &block);
54  Predictor(const Block &block, const std::string &output_node);
55  virtual ~Predictor();
56 
57  // set input array shapes
58  void set_input_shapes(const std::vector<std::string>& input_names, const std::vector<std::vector<mx_uint>>& input_shapes);
59 
60  // run prediction
61  const std::vector<float>& predict(const std::vector<std::vector<mx_float>>& input_data);
62 
63 private:
65 
66  void bind_executor();
67 
68  // context
69  static const Context context_;
70  // executor
71  std::unique_ptr<Executor> exec_;
72  // symbol
73  Symbol sym_;
74  // argument arrays
75  std::map<std::string, NDArray> arg_map_;
76  // auxiliary arrays
77  std::map<std::string, NDArray> aux_map_;
78  // output of the prediction
79  std::vector<float> pred_;
80  // names of the input nodes
81  std::vector<std::string> input_names_;
82 
83 };
84 
85 } /* namespace cpp */
86 } /* namespace mxnet */
87 
88 #endif /* PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_ */
const std::map< std::string, NDArray > & aux_map() const
Definition: Predictor.h:35
virtual ~Block()
Definition: Predictor.cc:28
std::map< std::string, NDArray > aux_map_
Definition: Predictor.h:45
static boost::mutex mutex
Definition: LHEProxy.cc:11
std::map< std::string, NDArray > arg_map_
Definition: Predictor.h:43
void load_parameters(const std::string &param_file)
Definition: Predictor.cc:31
std::map< std::string, NDArray > aux_map_
Definition: Predictor.h:77
std::vector< float > pred_
Definition: Predictor.h:79
const std::map< std::string, NDArray > & arg_map() const
Definition: Predictor.h:34
Symbol symbol(const std::string &output_node) const
Definition: Predictor.h:33
const Symbol & symbol() const
Definition: Predictor.h:32
static const Context context_
Definition: Predictor.h:69
std::map< std::string, NDArray > arg_map_
Definition: Predictor.h:75
std::unique_ptr< Executor > exec_
Definition: Predictor.h:71
std::vector< std::string > input_names_
Definition: Predictor.h:81
static std::mutex mutex_
Definition: Predictor.h:64