CMS 3D CMS Logo

ONNXRuntime.cc
Go to the documentation of this file.
1 /*
2  * ONNXRuntime.cc
3  *
4  * Created on: Jun 28, 2019
5  * Author: hqu
6  */
7 
9 
10 #include <cassert>
11 #include <iostream>
12 #include <algorithm>
13 #include <numeric>
14 #include <functional>
17 
18 namespace cms::Ort {
19 
20  using namespace ::Ort;
21 
// Definition of the static ONNXRuntime::env_ member: one Ort::Env object used for every Session created
// in the constructor below. It is built with ORT_LOGGING_LEVEL_WARNING and an empty string as second argument.
22  const Env ONNXRuntime::env_(ORT_LOGGING_LEVEL_WARNING, "");
23 
24  ONNXRuntime::ONNXRuntime(const std::string& model_path, const SessionOptions* session_options) {
25  // create session
26  if (session_options) {
27  session_.reset(new Session(env_, model_path.c_str(), *session_options));
28  } else {
29  SessionOptions sess_opts;
30  sess_opts.SetIntraOpNumThreads(1);
31  session_.reset(new Session(env_, model_path.c_str(), sess_opts));
32  }
33  AllocatorWithDefaultOptions allocator;
34 
35  // get input names and shapes
36  size_t num_input_nodes = session_->GetInputCount();
37  input_node_strings_.resize(num_input_nodes);
38  input_node_names_.resize(num_input_nodes);
39  input_node_dims_.clear();
40 
41  for (size_t i = 0; i < num_input_nodes; i++) {
42  // get input node names
43  std::string input_name(session_->GetInputName(i, allocator));
44  input_node_strings_[i] = input_name;
46 
47  // get input shapes
48  auto type_info = session_->GetInputTypeInfo(i);
49  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
50  size_t num_dims = tensor_info.GetDimensionsCount();
51  input_node_dims_[input_name].resize(num_dims);
52  tensor_info.GetDimensions(input_node_dims_[input_name].data(), num_dims);
53 
54  // set the batch size to 1 by default
55  input_node_dims_[input_name].at(0) = 1;
56  }
57 
58  size_t num_output_nodes = session_->GetOutputCount();
59  output_node_strings_.resize(num_output_nodes);
60  output_node_names_.resize(num_output_nodes);
61  output_node_dims_.clear();
62 
63  for (size_t i = 0; i < num_output_nodes; i++) {
64  // get output node names
65  std::string output_name(session_->GetOutputName(i, allocator));
66  output_node_strings_[i] = output_name;
68 
69  // get output node types
70  auto type_info = session_->GetOutputTypeInfo(i);
71  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
72  size_t num_dims = tensor_info.GetDimensionsCount();
73  output_node_dims_[output_name].resize(num_dims);
74  tensor_info.GetDimensions(output_node_dims_[output_name].data(), num_dims);
75 
76  // the 0th dim depends on the batch size
77  output_node_dims_[output_name].at(0) = -1;
78  }
79  }
80 
82 
83  FloatArrays ONNXRuntime::run(const std::vector<std::string>& input_names,
84  FloatArrays& input_values,
85  const std::vector<std::string>& output_names,
86  int64_t batch_size) const {
87  assert(input_names.size() == input_values.size());
88  assert(batch_size > 0);
89 
90  // create input tensor objects from data values
91  std::vector<Value> input_tensors;
92  auto memory_info = MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
93  for (const auto& name : input_node_strings_) {
94  auto iter = std::find(input_names.begin(), input_names.end(), name);
95  if (iter == input_names.end()) {
96  throw cms::Exception("RuntimeError") << "Input " << name << " is not provided!";
97  }
98  auto value = input_values.begin() + (iter - input_names.begin());
99  auto input_dims = input_node_dims_.at(name);
100  input_dims[0] = batch_size;
101  auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1, std::multiplies<int64_t>());
102  if (expected_len != (int64_t)value->size()) {
103  throw cms::Exception("RuntimeError")
104  << "Input array " << name << " has a wrong size of " << value->size() << ", expected " << expected_len;
105  }
106  auto input_tensor =
107  Value::CreateTensor<float>(memory_info, value->data(), value->size(), input_dims.data(), input_dims.size());
108  assert(input_tensor.IsTensor());
109  input_tensors.emplace_back(std::move(input_tensor));
110  }
111 
112  // set output node names; will get all outputs if `output_names` is not provided
113  std::vector<const char*> run_output_node_names;
114  if (output_names.empty()) {
115  run_output_node_names = output_node_names_;
116  } else {
117  for (const auto& name : output_names) {
118  run_output_node_names.push_back(name.c_str());
119  }
120  }
121 
122  // run
123  auto output_tensors = session_->Run(RunOptions{nullptr},
124  input_node_names_.data(),
125  input_tensors.data(),
126  input_tensors.size(),
127  run_output_node_names.data(),
128  run_output_node_names.size());
129 
130  // convert output to floats
132  for (auto& output_tensor : output_tensors) {
133  assert(output_tensor.IsTensor());
134 
135  // get output shape
136  auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
137  auto length = tensor_info.GetElementCount();
138 
139  auto floatarr = output_tensor.GetTensorMutableData<float>();
140  outputs.emplace_back(floatarr, floatarr + length);
141  }
142  assert(outputs.size() == run_output_node_names.size());
143 
144  return outputs;
145  }
146 
147  const std::vector<std::string>& ONNXRuntime::getOutputNames() const {
148  if (session_) {
149  return output_node_strings_;
150  } else {
151  throw cms::Exception("RuntimeError") << "Needs to call createSession() first before getting the output names!";
152  }
153  }
154 
155  const std::vector<int64_t>& ONNXRuntime::getOutputShape(const std::string& output_name) const {
156  auto iter = output_node_dims_.find(output_name);
157  if (iter == output_node_dims_.end()) {
158  throw cms::Exception("RuntimeError") << "Output name " << output_name << " is invalid!";
159  } else {
160  return iter->second;
161  }
162  }
163 
164 } /* namespace cms::Ort */
mps_fire.i
i
Definition: mps_fire.py:355
cms::Ort::ONNXRuntime::session_
std::unique_ptr<::Ort::Session > session_
Definition: ONNXRuntime.h:53
pfParticleNet_cff.model_path
model_path
Definition: pfParticleNet_cff.py:15
PatBasicFWLiteJetAnalyzer_Selector_cfg.outputs
outputs
Definition: PatBasicFWLiteJetAnalyzer_Selector_cfg.py:48
cms::Ort::ONNXRuntime::env_
static const ::Ort::Env env_
Definition: ONNXRuntime.h:52
cms::cuda::assert
assert(be >=bs)
cms::Ort::ONNXRuntime::ONNXRuntime
ONNXRuntime(const std::string &model_path, const ::Ort::SessionOptions *session_options=nullptr)
cms::Ort::ONNXRuntime::run
FloatArrays run(const std::vector< std::string > &input_names, FloatArrays &input_values, const std::vector< std::string > &output_names={}, int64_t batch_size=1) const
Definition: ONNXRuntime.cc:83
spr::find
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:19
cms::Ort::ONNXRuntime::getOutputNames
const std::vector< std::string > & getOutputNames() const
Definition: ONNXRuntime.cc:147
ONNXRuntime.h
cms::Ort
Definition: ONNXRuntime.h:21
Session
cms::Ort::ONNXRuntime::input_node_strings_
std::vector< std::string > input_node_strings_
Definition: ONNXRuntime.h:55
cms::Ort::ONNXRuntime::input_node_names_
std::vector< const char * > input_node_names_
Definition: ONNXRuntime.h:56
cms::Ort::ONNXRuntime::output_node_strings_
std::vector< std::string > output_node_strings_
Definition: ONNXRuntime.h:59
cms::Ort::ONNXRuntime::input_node_dims_
std::map< std::string, std::vector< int64_t > > input_node_dims_
Definition: ONNXRuntime.h:57
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
thread_safety_macros.h
value
Definition: value.py:1
cms::Ort::ONNXRuntime::getOutputShape
const std::vector< int64_t > & getOutputShape(const std::string &output_name) const
Definition: ONNXRuntime.cc:155
cms::Ort::FloatArrays
std::vector< std::vector< float > > FloatArrays
Definition: ONNXRuntime.h:23
eostools.move
def move(src, dest)
Definition: eostools.py:511
Exception
Definition: hltDiff.cc:246
Skims_PA_cff.name
name
Definition: Skims_PA_cff.py:17
Exception.h
data
char data[epos_bytes_allocation]
Definition: EPOS_Wrapper.h:79
cms::Ort::ONNXRuntime::output_node_names_
std::vector< const char * > output_node_names_
Definition: ONNXRuntime.h:60
pfParticleNetPreprocessParams_cfi.input_names
input_names
Definition: pfParticleNetPreprocessParams_cfi.py:4
cms::Ort::ONNXRuntime::output_node_dims_
std::map< std::string, std::vector< int64_t > > output_node_dims_
Definition: ONNXRuntime.h:61
cms::Ort::ONNXRuntime::~ONNXRuntime
~ONNXRuntime()
Definition: ONNXRuntime.cc:81