CMS 3D CMS Logo

List of all members | Public Member Functions | Static Public Member Functions | Private Attributes | Static Private Attributes
cms::Ort::ONNXRuntime Class Reference

#include <ONNXRuntime.h>

Public Member Functions

const std::vector< std::string > & getOutputNames () const
 
const std::vector< int64_t > & getOutputShape (const std::string &output_name) const
 
 ONNXRuntime (const std::string &model_path, const ::Ort::SessionOptions *session_options=nullptr)
 
 ONNXRuntime (const ONNXRuntime &)=delete
 
ONNXRuntime & operator= (const ONNXRuntime &)=delete
 
FloatArrays run (const std::vector< std::string > &input_names, FloatArrays &input_values, const std::vector< std::vector< int64_t >> &input_shapes={}, const std::vector< std::string > &output_names={}, int64_t batch_size=1) const
 
 ~ONNXRuntime ()
 

Static Public Member Functions

::Ort::SessionOptions defaultSessionOptions (Backend backend=Backend::cpu)
 

Private Attributes

std::map< std::string, std::vector< int64_t > > input_node_dims_
 
std::vector< const char * > input_node_names_
 
std::vector< std::string > input_node_strings_
 
std::map< std::string, std::vector< int64_t > > output_node_dims_
 
std::vector< const char * > output_node_names_
 
std::vector< std::string > output_node_strings_
 
std::unique_ptr<::Ort::Session > session_
 

Static Private Attributes

static const ::Ort::Env env_
 

Detailed Description

Definition at line 30 of file ONNXRuntime.h.

Constructor & Destructor Documentation

◆ ONNXRuntime() [1/2]

cms::Ort::ONNXRuntime::ONNXRuntime ( const std::string &  model_path,
const ::Ort::SessionOptions *  session_options = nullptr 
)

◆ ONNXRuntime() [2/2]

cms::Ort::ONNXRuntime::ONNXRuntime ( const ONNXRuntime & )
delete

◆ ~ONNXRuntime()

cms::Ort::ONNXRuntime::~ONNXRuntime ( )

Definition at line 77 of file ONNXRuntime.cc.

77 {}

Member Function Documentation

◆ defaultSessionOptions()

SessionOptions cms::Ort::ONNXRuntime::defaultSessionOptions ( Backend  backend = Backend::cpu)
static

Definition at line 79 of file ONNXRuntime.cc.

References the `backend` parameter, cms::Ort::cuda, and the local `options` variable (an OrtCUDAProviderOptions instance).

79  {
80  SessionOptions sess_opts;
81  sess_opts.SetIntraOpNumThreads(1);
82  if (backend == Backend::cuda) {
83  // https://www.onnxruntime.ai/docs/reference/execution-providers/CUDA-ExecutionProvider.html
84  OrtCUDAProviderOptions options;
85  sess_opts.AppendExecutionProvider_CUDA(options);
86  }
87  return sess_opts;
88  }

◆ getOutputNames()

const std::vector< std::string > & cms::Ort::ONNXRuntime::getOutputNames ( ) const

Definition at line 167 of file ONNXRuntime.cc.

References Exception, output_node_strings_, and session_.

167  {
168  if (session_) {
169  return output_node_strings_;
170  } else {
171  throw cms::Exception("RuntimeError") << "Needs to call createSession() first before getting the output names!";
172  }
173  }
std::unique_ptr<::Ort::Session > session_
Definition: ONNXRuntime.h:62
std::vector< std::string > output_node_strings_
Definition: ONNXRuntime.h:68

◆ getOutputShape()

const std::vector< int64_t > & cms::Ort::ONNXRuntime::getOutputShape ( const std::string &  output_name) const

Definition at line 175 of file ONNXRuntime.cc.

References Exception, and output_node_dims_.

175  {
176  auto iter = output_node_dims_.find(output_name);
177  if (iter == output_node_dims_.end()) {
178  throw cms::Exception("RuntimeError") << "Output name " << output_name << " is invalid!";
179  } else {
180  return iter->second;
181  }
182  }
std::map< std::string, std::vector< int64_t > > output_node_dims_
Definition: ONNXRuntime.h:70

◆ operator=()

ONNXRuntime& cms::Ort::ONNXRuntime::operator= ( const ONNXRuntime & )
delete

◆ run()

FloatArrays cms::Ort::ONNXRuntime::run ( const std::vector< std::string > &  input_names,
FloatArrays &  input_values,
const std::vector< std::vector< int64_t >> &  input_shapes = {},
const std::vector< std::string > &  output_names = {},
int64_t  batch_size = 1 
) const

Definition at line 90 of file ONNXRuntime.cc.

References assert(), Exception, std::find(), the `input_names` parameter, input_node_dims_, input_node_names_, input_node_strings_, std::move(), the loop variable `name`, the `output_names` parameter, output_node_names_, the local `outputs` variable, and session_.

94  {
95  assert(input_names.size() == input_values.size());
96  assert(input_shapes.empty() || input_names.size() == input_shapes.size());
97  assert(batch_size > 0);
98 
99  // create input tensor objects from data values
100  std::vector<Value> input_tensors;
101  auto memory_info = MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
102  for (const auto& name : input_node_strings_) {
103  auto iter = std::find(input_names.begin(), input_names.end(), name);
104  if (iter == input_names.end()) {
105  throw cms::Exception("RuntimeError") << "Input " << name << " is not provided!";
106  }
107  auto input_pos = iter - input_names.begin();
108  auto value = input_values.begin() + input_pos;
109  std::vector<int64_t> input_dims;
110  if (input_shapes.empty()) {
111  input_dims = input_node_dims_.at(name);
112  input_dims[0] = batch_size;
113  } else {
114  input_dims = input_shapes[input_pos];
115  // rely on the given input_shapes to set the batch size
116  if (input_dims[0] != batch_size) {
117  throw cms::Exception("RuntimeError") << "The first element of `input_shapes` (" << input_dims[0]
118  << ") does not match the given `batch_size` (" << batch_size << ")";
119  }
120  }
121  auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1, std::multiplies<int64_t>());
122  if (expected_len != (int64_t)value->size()) {
123  throw cms::Exception("RuntimeError")
124  << "Input array " << name << " has a wrong size of " << value->size() << ", expected " << expected_len;
125  }
126  auto input_tensor =
127  Value::CreateTensor<float>(memory_info, value->data(), value->size(), input_dims.data(), input_dims.size());
128  assert(input_tensor.IsTensor());
129  input_tensors.emplace_back(std::move(input_tensor));
130  }
131 
132  // set output node names; will get all outputs if `output_names` is not provided
133  std::vector<const char*> run_output_node_names;
134  if (output_names.empty()) {
135  run_output_node_names = output_node_names_;
136  } else {
137  for (const auto& name : output_names) {
138  run_output_node_names.push_back(name.c_str());
139  }
140  }
141 
142  // run
143  auto output_tensors = session_->Run(RunOptions{nullptr},
144  input_node_names_.data(),
145  input_tensors.data(),
146  input_tensors.size(),
147  run_output_node_names.data(),
148  run_output_node_names.size());
149 
150  // convert output to floats
151  FloatArrays outputs;
152  for (auto& output_tensor : output_tensors) {
153  assert(output_tensor.IsTensor());
154 
155  // get output shape
156  auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
157  auto length = tensor_info.GetElementCount();
158 
159  auto floatarr = output_tensor.GetTensorMutableData<float>();
160  outputs.emplace_back(floatarr, floatarr + length);
161  }
162  assert(outputs.size() == run_output_node_names.size());
163 
164  return outputs;
165  }
std::unique_ptr<::Ort::Session > session_
Definition: ONNXRuntime.h:62
std::map< std::string, std::vector< int64_t > > input_node_dims_
Definition: ONNXRuntime.h:66
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:19
std::vector< std::vector< float > > FloatArrays
Definition: ONNXRuntime.h:23
assert(be >=bs)
Definition: value.py:1
std::vector< const char * > output_node_names_
Definition: ONNXRuntime.h:69
std::vector< std::string > input_node_strings_
Definition: ONNXRuntime.h:64
std::vector< const char * > input_node_names_
Definition: ONNXRuntime.h:65
def move(src, dest)
Definition: eostools.py:511

Member Data Documentation

◆ env_

const Env cms::Ort::ONNXRuntime::env_
static private

Definition at line 61 of file ONNXRuntime.h.

◆ input_node_dims_

std::map<std::string, std::vector<int64_t> > cms::Ort::ONNXRuntime::input_node_dims_
private

Definition at line 66 of file ONNXRuntime.h.

Referenced by run().

◆ input_node_names_

std::vector<const char*> cms::Ort::ONNXRuntime::input_node_names_
private

Definition at line 65 of file ONNXRuntime.h.

Referenced by run().

◆ input_node_strings_

std::vector<std::string> cms::Ort::ONNXRuntime::input_node_strings_
private

Definition at line 64 of file ONNXRuntime.h.

Referenced by run().

◆ output_node_dims_

std::map<std::string, std::vector<int64_t> > cms::Ort::ONNXRuntime::output_node_dims_
private

Definition at line 70 of file ONNXRuntime.h.

Referenced by getOutputShape().

◆ output_node_names_

std::vector<const char*> cms::Ort::ONNXRuntime::output_node_names_
private

Definition at line 69 of file ONNXRuntime.h.

Referenced by run().

◆ output_node_strings_

std::vector<std::string> cms::Ort::ONNXRuntime::output_node_strings_
private

Definition at line 68 of file ONNXRuntime.h.

Referenced by getOutputNames().

◆ session_

std::unique_ptr<::Ort::Session> cms::Ort::ONNXRuntime::session_
private

Definition at line 62 of file ONNXRuntime.h.

Referenced by getOutputNames(), and run().