#include <ONNXRuntime.h>
Public Member Functions | |
const std::vector< std::string > & | getOutputNames () const |
const std::vector< int64_t > & | getOutputShape (const std::string &output_name) const |
ONNXRuntime (const ONNXRuntime &)=delete | |
ONNXRuntime (const std::string &model_path, const ::Ort::SessionOptions *session_options=nullptr) | |
ONNXRuntime & | operator= (const ONNXRuntime &)=delete |
FloatArrays | run (const std::vector< std::string > &input_names, FloatArrays &input_values, const std::vector< std::vector< int64_t >> &input_shapes={}, const std::vector< std::string > &output_names={}, int64_t batch_size=1) const |
~ONNXRuntime () | |
Private Attributes | |
std::map< std::string, std::vector< int64_t > > | input_node_dims_ |
std::vector< const char * > | input_node_names_ |
std::vector< std::string > | input_node_strings_ |
std::map< std::string, std::vector< int64_t > > | output_node_dims_ |
std::vector< const char * > | output_node_names_ |
std::vector< std::string > | output_node_strings_ |
std::unique_ptr<::Ort::Session > | session_ |
Static Private Attributes | |
static const ::Ort::Env | env_ |
Definition at line 25 of file ONNXRuntime.h.
cms::Ort::ONNXRuntime::ONNXRuntime(const std::string &model_path, const ::Ort::SessionOptions *session_options = nullptr)
|
delete |
cms::Ort::ONNXRuntime::~ONNXRuntime | ( | ) |
Definition at line 78 of file ONNXRuntime.cc.
const std::vector< std::string > & cms::Ort::ONNXRuntime::getOutputNames() const
Definition at line 153 of file ONNXRuntime.cc.
References Exception, output_node_strings_, and session_.
Referenced by TrackQuality::setTrackQuality().
const std::vector< int64_t > & cms::Ort::ONNXRuntime::getOutputShape(const std::string &output_name) const
|
delete |
FloatArrays cms::Ort::ONNXRuntime::run(const std::vector< std::string > &input_names, FloatArrays &input_values, const std::vector< std::vector< int64_t >> &input_shapes = {}, const std::vector< std::string > &output_names = {}, int64_t batch_size = 1) const
Definition at line 80 of file ONNXRuntime.cc.
References cms::cuda::assert(), Exception, spr::find(), pfDeepBoostedJetPreprocessParams_cfi::input_names, input_node_dims_, input_node_names_, input_node_strings_, eostools::move(), Skims_PA_cff::name, output_node_names_, PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and session_.
|
staticprivate |
Definition at line 54 of file ONNXRuntime.h.
|
private |
Definition at line 59 of file ONNXRuntime.h.
Referenced by run().
|
private |
Definition at line 58 of file ONNXRuntime.h.
Referenced by run().
|
private |
Definition at line 57 of file ONNXRuntime.h.
Referenced by run().
|
private |
Definition at line 63 of file ONNXRuntime.h.
Referenced by getOutputShape().
|
private |
Definition at line 62 of file ONNXRuntime.h.
Referenced by run().
|
private |
Definition at line 61 of file ONNXRuntime.h.
Referenced by getOutputNames().
|
private |
Definition at line 55 of file ONNXRuntime.h.
Referenced by getOutputNames(), and run().