20 using namespace ::Ort;
26 if (session_options) {
29 SessionOptions sess_opts;
30 sess_opts.SetIntraOpNumThreads(1);
33 AllocatorWithDefaultOptions allocator;
36 size_t num_input_nodes =
session_->GetInputCount();
41 for (
size_t i = 0;
i < num_input_nodes;
i++) {
48 auto type_info =
session_->GetInputTypeInfo(
i);
49 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
50 size_t num_dims = tensor_info.GetDimensionsCount();
58 size_t num_output_nodes =
session_->GetOutputCount();
63 for (
size_t i = 0;
i < num_output_nodes;
i++) {
70 auto type_info =
session_->GetOutputTypeInfo(
i);
71 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
72 size_t num_dims = tensor_info.GetDimensionsCount();
85 const std::vector<std::string>& output_names,
86 int64_t batch_size)
const {
91 std::vector<Value> input_tensors;
92 auto memory_info = MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
100 input_dims[0] = batch_size;
101 auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1, std::multiplies<int64_t>());
102 if (expected_len != (int64_t)
value->size()) {
104 <<
"Input array " <<
name <<
" has a wrong size of " <<
value->size() <<
", expected " << expected_len;
107 Value::CreateTensor<float>(memory_info,
value->data(),
value->size(), input_dims.data(), input_dims.size());
108 assert(input_tensor.IsTensor());
109 input_tensors.emplace_back(
std::move(input_tensor));
113 std::vector<const char*> run_output_node_names;
114 if (output_names.empty()) {
117 for (
const auto&
name : output_names) {
118 run_output_node_names.push_back(
name.c_str());
123 auto output_tensors =
session_->Run(RunOptions{
nullptr},
125 input_tensors.data(),
126 input_tensors.size(),
127 run_output_node_names.data(),
128 run_output_node_names.size());
132 for (
auto& output_tensor : output_tensors) {
133 assert(output_tensor.IsTensor());
136 auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
137 auto length = tensor_info.GetElementCount();
139 auto floatarr = output_tensor.GetTensorMutableData<
float>();
140 outputs.emplace_back(floatarr, floatarr + length);
151 throw cms::Exception(
"RuntimeError") <<
"Needs to call createSession() first before getting the output names!";
158 throw cms::Exception(
"RuntimeError") <<
"Output name " << output_name <<
" is invalid!";