CMS 3D CMS Logo

Predictor.h
Go to the documentation of this file.
1 /*
2  * MXNetCppPredictor.h
3  *
4  * Created on: Jul 19, 2018
5  * Author: hqu
6  */
7 
8 #ifndef PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_
9 #define PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_
10 
11 #include <map>
12 #include <vector>
13 #include <memory>
14 #include <mutex>
15 
16 #include "mxnet-cpp/MxNetCpp.h"
17 
18 namespace mxnet {
19 
20  namespace cpp {
21 
22  // note: Most of the objects in mxnet::cpp are effectively just shared_ptr's
23 
24  // Simple class to hold MXNet model (symbol + params)
25  // designed to be sharable by multiple threads
26  class Block {
27  public:
28  Block();
29  Block(const std::string& symbol_file, const std::string& param_file);
30  virtual ~Block();
31 
32  const Symbol& symbol() const { return sym_; }
33  Symbol symbol(const std::string& output_node) const { return sym_.GetInternals()[output_node]; }
34  const std::map<std::string, NDArray>& arg_map() const { return arg_map_; }
35  const std::map<std::string, NDArray>& aux_map() const { return aux_map_; }
36 
37  private:
38  void load_parameters(const std::string& param_file);
39 
40  // symbol
41  Symbol sym_;
42  // argument arrays
43  std::map<std::string, NDArray> arg_map_;
44  // auxiliary arrays
45  std::map<std::string, NDArray> aux_map_;
46  };
47 
48  // Simple helper class to run prediction
49  // this cannot be shared between threads
50  class Predictor {
51  public:
52  Predictor();
53  Predictor(const Block& block);
54  Predictor(const Block& block, const std::string& output_node);
55  virtual ~Predictor();
56 
57  // set input array shapes
58  void set_input_shapes(const std::vector<std::string>& input_names,
59  const std::vector<std::vector<mx_uint>>& input_shapes);
60 
61  // run prediction
62  const std::vector<float>& predict(const std::vector<std::vector<mx_float>>& input_data);
63 
64  private:
66 
67  void bind_executor();
68 
69  // context
70  static const Context context_;
71  // executor
72  std::unique_ptr<Executor> exec_;
73  // symbol
74  Symbol sym_;
75  // argument arrays
76  std::map<std::string, NDArray> arg_map_;
77  // auxiliary arrays
78  std::map<std::string, NDArray> aux_map_;
79  // output of the prediction
80  std::vector<float> pred_;
81  // names of the input nodes
82  std::vector<std::string> input_names_;
83  };
84 
85  } /* namespace cpp */
86 } /* namespace mxnet */
87 
88 #endif /* PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_ */
virtual ~Block()
Definition: Predictor.cc:28
std::map< std::string, NDArray > aux_map_
Definition: Predictor.h:45
static std::mutex mutex
Definition: Proxy.cc:8
std::map< std::string, NDArray > arg_map_
Definition: Predictor.h:43
void load_parameters(const std::string &param_file)
Definition: Predictor.cc:30
virtual ~Predictor()
Definition: Predictor.cc:56
std::map< std::string, NDArray > aux_map_
Definition: Predictor.h:78
Symbol symbol(const std::string &output_node) const
Definition: Predictor.h:33
std::vector< float > pred_
Definition: Predictor.h:80
const Symbol & symbol() const
Definition: Predictor.h:32
const std::map< std::string, NDArray > & aux_map() const
Definition: Predictor.h:35
const std::map< std::string, NDArray > & arg_map() const
Definition: Predictor.h:34
static const Context context_
Definition: Predictor.h:70
std::map< std::string, NDArray > arg_map_
Definition: Predictor.h:76
std::unique_ptr< Executor > exec_
Definition: Predictor.h:72
const std::vector< float > & predict(const std::vector< std::vector< mx_float >> &input_data)
Definition: Predictor.cc:69
void set_input_shapes(const std::vector< std::string > &input_names, const std::vector< std::vector< mx_uint >> &input_shapes)
Definition: Predictor.cc:58
std::vector< std::string > input_names_
Definition: Predictor.h:82
static std::mutex mutex_
Definition: Predictor.h:65