86 assert(input_shapes.empty() ||
input_names.size() == input_shapes.size());
87 assert(batch_size > 0);
90 std::vector<Value> input_tensors;
91 auto memory_info = MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
98 auto value = input_values.begin() + input_pos;
99 std::vector<int64_t> input_dims;
100 if (input_shapes.empty()) {
102 input_dims[0] = batch_size;
104 input_dims = input_shapes[input_pos];
107 auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1, std::multiplies<int64_t>());
108 if (expected_len != (int64_t)
value->size()) {
110 <<
"Input array " <<
name <<
" has a wrong size of " <<
value->size() <<
", expected " << expected_len;
113 Value::CreateTensor<float>(memory_info,
value->data(),
value->size(), input_dims.data(), input_dims.size());
114 assert(input_tensor.IsTensor());
115 input_tensors.emplace_back(
std::move(input_tensor));
119 std::vector<const char*> run_output_node_names;
120 if (output_names.empty()) {
123 for (
const auto&
name : output_names) {
124 run_output_node_names.push_back(
name.c_str());
129 auto output_tensors =
session_->Run(RunOptions{
nullptr},
131 input_tensors.data(),
132 input_tensors.size(),
133 run_output_node_names.data(),
134 run_output_node_names.size());
138 for (
auto& output_tensor : output_tensors) {
139 assert(output_tensor.IsTensor());
142 auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
143 auto length = tensor_info.GetElementCount();
145 auto floatarr = output_tensor.GetTensorMutableData<
float>();
146 outputs.emplace_back(floatarr, floatarr + length);
148 assert(outputs.size() == run_output_node_names.size());
std::unique_ptr<::Ort::Session > session_
std::map< std::string, std::vector< int64_t > > input_node_dims_
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
std::vector< std::vector< float > > FloatArrays
std::vector< const char * > output_node_names_
std::vector< std::string > input_node_strings_
std::vector< const char * > input_node_names_