CMS 3D CMS Logo

ONNXRuntime.h
Go to the documentation of this file.
1 /*
2  * ONNXRuntime.h
3  *
4  * A convenience wrapper of the ONNXRuntime C++ API.
5  * Based on https://github.com/microsoft/onnxruntime/blob/master/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp.
6  *
7  * Created on: Jun 28, 2019
8  * Author: hqu
9  */
10 
11 #ifndef PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
12 #define PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
13 
14 #include <vector>
15 #include <map>
16 #include <string>
17 #include <memory>
18 
19 #include "onnxruntime/core/session/onnxruntime_cxx_api.h"
20 
21 namespace cms::Ort {
22 
23  typedef std::vector<std::vector<float>> FloatArrays;
24 
25  class ONNXRuntime {
26  public:
27  ONNXRuntime(const std::string& model_path, const ::Ort::SessionOptions* session_options = nullptr);
28  ONNXRuntime(const ONNXRuntime&) = delete;
29  ONNXRuntime& operator=(const ONNXRuntime&) = delete;
30  ~ONNXRuntime();
31 
32  // Run inference and get outputs
33  // input_names: list of the names of the input nodes.
34  // input_values: list of input arrays for each input node. The order of `input_values` must match `input_names`.
35  // input_shapes: list of `int64_t` arrays specifying the shape of each input node. Can leave empty if the model does not have dynamic axes.
36  // output_names: names of the output nodes to get outputs from. Empty list means all output nodes.
37  // batch_size: number of samples in the batch. Each array in `input_values` must have a shape layout of (batch_size, ...).
38  // Returns: a std::vector<std::vector<float>>, with the order matched to `output_names`.
39  // When `output_names` is empty, will return all outputs ordered as in `getOutputNames()`.
40  FloatArrays run(const std::vector<std::string>& input_names,
41  FloatArrays& input_values,
42  const std::vector<std::vector<int64_t>>& input_shapes = {},
43  const std::vector<std::string>& output_names = {},
44  int64_t batch_size = 1) const;
45 
46  // Get a list of names of all the output nodes
47  const std::vector<std::string>& getOutputNames() const;
48 
49  // Get the shape of a output node
50  // The 0th dim depends on the batch size, therefore is set to -1
51  const std::vector<int64_t>& getOutputShape(const std::string& output_name) const;
52 
53  private:
54  static const ::Ort::Env env_;
55  std::unique_ptr<::Ort::Session> session_;
56 
57  std::vector<std::string> input_node_strings_;
58  std::vector<const char*> input_node_names_;
59  std::map<std::string, std::vector<int64_t>> input_node_dims_;
60 
61  std::vector<std::string> output_node_strings_;
62  std::vector<const char*> output_node_names_;
63  std::map<std::string, std::vector<int64_t>> output_node_dims_;
64  };
65 
66 } // namespace cms::Ort
67 
68 #endif /* PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_ */
std::unique_ptr<::Ort::Session > session_
Definition: ONNXRuntime.h:55
std::map< std::string, std::vector< int64_t > > input_node_dims_
Definition: ONNXRuntime.h:59
std::map< std::string, std::vector< int64_t > > output_node_dims_
Definition: ONNXRuntime.h:63
static const ::Ort::Env env_
Definition: ONNXRuntime.h:54
std::vector< std::vector< float > > FloatArrays
Definition: ONNXRuntime.h:23
FloatArrays run(const std::vector< std::string > &input_names, FloatArrays &input_values, const std::vector< std::vector< int64_t >> &input_shapes={}, const std::vector< std::string > &output_names={}, int64_t batch_size=1) const
Definition: ONNXRuntime.cc:80
ONNXRuntime(const std::string &model_path, const ::Ort::SessionOptions *session_options=nullptr)
std::vector< const char * > output_node_names_
Definition: ONNXRuntime.h:62
const std::vector< std::string > & getOutputNames() const
Definition: ONNXRuntime.cc:153
const std::vector< int64_t > & getOutputShape(const std::string &output_name) const
Definition: ONNXRuntime.cc:161
std::vector< std::string > input_node_strings_
Definition: ONNXRuntime.h:57
std::vector< const char * > input_node_names_
Definition: ONNXRuntime.h:58
std::vector< std::string > output_node_strings_
Definition: ONNXRuntime.h:61
ONNXRuntime & operator=(const ONNXRuntime &)=delete