CMS 3D CMS Logo

ONNXRuntime.h
Go to the documentation of this file.
/*
 * ONNXRuntime.h
 *
 * A convenience wrapper around the ONNX Runtime C++ API.
 * Based on https://github.com/microsoft/onnxruntime/blob/master/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp.
 *
 * Created on: Jun 28, 2019
 * Author: hqu
 */
10 
11 #ifndef PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
12 #define PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
13 
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "onnxruntime/core/session/onnxruntime_cxx_api.h"
20 
21 namespace cms::Ort {
22 
23  typedef std::vector<std::vector<float>> FloatArrays;
24 
25  enum class Backend {
26  cpu,
27  cuda,
28  };
29 
30  class ONNXRuntime {
31  public:
32  ONNXRuntime(const std::string& model_path, const ::Ort::SessionOptions* session_options = nullptr);
33  ONNXRuntime(const ONNXRuntime&) = delete;
34  ONNXRuntime& operator=(const ONNXRuntime&) = delete;
35  ~ONNXRuntime();
36 
37  static ::Ort::SessionOptions defaultSessionOptions(Backend backend = Backend::cpu);
38 
39  // Run inference and get outputs
40  // input_names: list of the names of the input nodes.
41  // input_values: list of input arrays for each input node. The order of `input_values` must match `input_names`.
42  // input_shapes: list of `int64_t` arrays specifying the shape of each input node. Can leave empty if the model does not have dynamic axes.
43  // output_names: names of the output nodes to get outputs from. Empty list means all output nodes.
44  // batch_size: number of samples in the batch. Each array in `input_values` must have a shape layout of (batch_size, ...).
45  // Returns: a std::vector<std::vector<float>>, with the order matched to `output_names`.
46  // When `output_names` is empty, will return all outputs ordered as in `getOutputNames()`.
47  FloatArrays run(const std::vector<std::string>& input_names,
48  FloatArrays& input_values,
49  const std::vector<std::vector<int64_t>>& input_shapes = {},
50  const std::vector<std::string>& output_names = {},
51  int64_t batch_size = 1) const;
52 
53  // Get a list of names of all the output nodes
54  const std::vector<std::string>& getOutputNames() const;
55 
56  // Get the shape of a output node
57  // The 0th dim depends on the batch size, therefore is set to -1
58  const std::vector<int64_t>& getOutputShape(const std::string& output_name) const;
59 
60  private:
61  static const ::Ort::Env env_;
62  std::unique_ptr<::Ort::Session> session_;
63 
64  std::vector<std::string> input_node_strings_;
65  std::vector<const char*> input_node_names_;
66  std::map<std::string, std::vector<int64_t>> input_node_dims_;
67 
68  std::vector<std::string> output_node_strings_;
69  std::vector<const char*> output_node_names_;
70  std::map<std::string, std::vector<int64_t>> output_node_dims_;
71  };
72 
73 } // namespace cms::Ort
74 
75 #endif /* PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_ */
std::unique_ptr<::Ort::Session > session_
Definition: ONNXRuntime.h:62
std::map< std::string, std::vector< int64_t > > input_node_dims_
Definition: ONNXRuntime.h:66
::Ort::SessionOptions defaultSessionOptions(Backend backend=Backend::cpu)
Definition: ONNXRuntime.cc:79
std::map< std::string, std::vector< int64_t > > output_node_dims_
Definition: ONNXRuntime.h:70
static const ::Ort::Env env_
Definition: ONNXRuntime.h:61
std::vector< std::vector< float > > FloatArrays
Definition: ONNXRuntime.h:23
const std::vector< std::string > & getOutputNames() const
Definition: ONNXRuntime.cc:167
ONNXRuntime(const std::string &model_path, const ::Ort::SessionOptions *session_options=nullptr)
const std::vector< int64_t > & getOutputShape(const std::string &output_name) const
Definition: ONNXRuntime.cc:175
std::vector< const char * > output_node_names_
Definition: ONNXRuntime.h:69
std::vector< std::string > input_node_strings_
Definition: ONNXRuntime.h:64
std::vector< const char * > input_node_names_
Definition: ONNXRuntime.h:65
std::vector< std::string > output_node_strings_
Definition: ONNXRuntime.h:68
ONNXRuntime & operator=(const ONNXRuntime &)=delete
FloatArrays run(const std::vector< std::string > &input_names, FloatArrays &input_values, const std::vector< std::vector< int64_t >> &input_shapes={}, const std::vector< std::string > &output_names={}, int64_t batch_size=1) const
Definition: ONNXRuntime.cc:90