CMS 3D CMS Logo

ONNXRuntime.h
Go to the documentation of this file.
1 /*
2  * ONNXRuntime.h
3  *
4  * A convenience wrapper of the ONNXRuntime C++ API.
5  * Based on https://github.com/microsoft/onnxruntime/blob/master/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp.
6  *
7  * Created on: Jun 28, 2019
8  * Author: hqu
9  */
10 
11 #ifndef PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
12 #define PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
13 
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "onnxruntime/core/session/onnxruntime_cxx_api.h"
20 
21 namespace cms::Ort {
22 
23  typedef std::vector<std::vector<float>> FloatArrays;
24 
25  class ONNXRuntime {
26  public:
27  ONNXRuntime(const std::string& model_path, const ::Ort::SessionOptions* session_options = nullptr);
28  ONNXRuntime(const ONNXRuntime&) = delete;
29  ONNXRuntime& operator=(const ONNXRuntime&) = delete;
30  ~ONNXRuntime();
31 
32  // Run inference and get outputs
33  // input_names: list of the names of the input nodes.
34  // input_values: list of input arrays for each input node. The order of `input_values` must match `input_names`.
35  // output_names: names of the output nodes to get outputs from. Empty list means all output nodes.
36  // batch_size: number of samples in the batch. Each array in `input_values` must have a shape layout of (batch_size, ...).
37  // Returns: a std::vector<std::vector<float>>, with the order matched to `output_names`.
38  // When `output_names` is empty, will return all outputs ordered as in `getOutputNames()`.
39  FloatArrays run(const std::vector<std::string>& input_names,
40  FloatArrays& input_values,
41  const std::vector<std::string>& output_names = {},
42  int64_t batch_size = 1) const;
43 
44  // Get a list of names of all the output nodes
45  const std::vector<std::string>& getOutputNames() const;
46 
47  // Get the shape of a output node
48  // The 0th dim depends on the batch size, therefore is set to -1
49  const std::vector<int64_t>& getOutputShape(const std::string& output_name) const;
50 
51  private:
52  static const ::Ort::Env env_;
53  std::unique_ptr<::Ort::Session> session_;
54 
55  std::vector<std::string> input_node_strings_;
56  std::vector<const char*> input_node_names_;
57  std::map<std::string, std::vector<int64_t>> input_node_dims_;
58 
59  std::vector<std::string> output_node_strings_;
60  std::vector<const char*> output_node_names_;
61  std::map<std::string, std::vector<int64_t>> output_node_dims_;
62  };
63 
64 } // namespace cms::Ort
65 
66 #endif /* PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_ */
cms::Ort::ONNXRuntime::operator=
ONNXRuntime & operator=(const ONNXRuntime &)=delete
cms::Ort::ONNXRuntime::session_
std::unique_ptr<::Ort::Session > session_
Definition: ONNXRuntime.h:53
hltPfDeepFlavourJetTags_cfi.output_names
output_names
Definition: hltPfDeepFlavourJetTags_cfi.py:21
cms::Ort::ONNXRuntime::env_
static const ::Ort::Env env_
Definition: ONNXRuntime.h:52
cms::Ort::ONNXRuntime::ONNXRuntime
ONNXRuntime(const std::string &model_path, const ::Ort::SessionOptions *session_options=nullptr)
cms::Ort::ONNXRuntime::run
FloatArrays run(const std::vector< std::string > &input_names, FloatArrays &input_values, const std::vector< std::string > &output_names={}, int64_t batch_size=1) const
Definition: ONNXRuntime.cc:83
cms::Ort::ONNXRuntime::getOutputNames
const std::vector< std::string > & getOutputNames() const
Definition: ONNXRuntime.cc:147
cms::Ort
Definition: ONNXRuntime.h:21
cms::Ort::ONNXRuntime::input_node_strings_
std::vector< std::string > input_node_strings_
Definition: ONNXRuntime.h:55
cms::Ort::ONNXRuntime::input_node_names_
std::vector< const char * > input_node_names_
Definition: ONNXRuntime.h:56
cms::Ort::ONNXRuntime::output_node_strings_
std::vector< std::string > output_node_strings_
Definition: ONNXRuntime.h:59
cms::Ort::ONNXRuntime::input_node_dims_
std::map< std::string, std::vector< int64_t > > input_node_dims_
Definition: ONNXRuntime.h:57
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
hltPfDeepFlavourJetTags_cfi.input_names
input_names
Definition: hltPfDeepFlavourJetTags_cfi.py:12
cms::Ort::ONNXRuntime::getOutputShape
const std::vector< int64_t > & getOutputShape(const std::string &output_name) const
Definition: ONNXRuntime.cc:155
cms::Ort::FloatArrays
std::vector< std::vector< float > > FloatArrays
Definition: ONNXRuntime.h:23
cms::Ort::ONNXRuntime
Definition: ONNXRuntime.h:25
hltPfDeepFlavourJetTags_cfi.model_path
model_path
Definition: hltPfDeepFlavourJetTags_cfi.py:20
cms::Ort::ONNXRuntime::output_node_names_
std::vector< const char * > output_node_names_
Definition: ONNXRuntime.h:60
cms::Ort::ONNXRuntime::output_node_dims_
std::map< std::string, std::vector< int64_t > > output_node_dims_
Definition: ONNXRuntime.h:61
cms::Ort::ONNXRuntime::~ONNXRuntime
~ONNXRuntime()
Definition: ONNXRuntime.cc:81