CMS 3D CMS Logo

List of all members | Public Member Functions | Private Attributes
BaseMVACache Class Reference

#include <BaseMVAValueMapProducer.h>

Public Member Functions

 BaseMVACache (const std::string &model_path, const std::string &backend, const bool disableONNXGraphOpt)
 
const cms::Ort::ONNXRuntime & getONNXSession () const
 
tensorflow::Session * getTFSession () const
 
 ~BaseMVACache ()
 

Private Attributes

std::shared_ptr< tensorflow::GraphDef > graph_
 
std::unique_ptr< cms::Ort::ONNXRuntime > ort_
 
tensorflow::Session * tf_session_ = nullptr
 

Detailed Description

Definition at line 56 of file BaseMVAValueMapProducer.h.

Constructor & Destructor Documentation

◆ BaseMVACache()

BaseMVACache::BaseMVACache (const std::string & model_path, const std::string & backend, const bool disableONNXGraphOpt)
inline

Definition at line 58 of file BaseMVAValueMapProducer.h.

References the constructor parameters backend and model_path, tensorflow::createSession(), cms::Ort::ONNXRuntime::defaultSessionOptions(), graph_, tensorflow::loadGraphDef(), ort_, and tf_session_.

58  {
59  if (backend == "TF") {
62  } else if (backend == "ONNX") {
63  if (disableONNXGraphOpt) {
64  Ort::SessionOptions sess_opts;
66  sess_opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
67  ort_ = std::make_unique<cms::Ort::ONNXRuntime>(model_path, &sess_opts);
68  } else {
69  ort_ = std::make_unique<cms::Ort::ONNXRuntime>(model_path);
70  }
71  }
72  }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:84
::Ort::SessionOptions defaultSessionOptions(Backend backend=Backend::cpu)
Definition: ONNXRuntime.cc:79
std::shared_ptr< tensorflow::GraphDef > graph_
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:67
std::unique_ptr< cms::Ort::ONNXRuntime > ort_
tensorflow::Session * tf_session_

◆ ~BaseMVACache()

BaseMVACache::~BaseMVACache ( )
inline

Definition at line 73 of file BaseMVAValueMapProducer.h.

References tensorflow::closeSession(), and tf_session_.

bool closeSession(Session *&session)
Definition: TensorFlow.cc:197
tensorflow::Session * tf_session_

Member Function Documentation

◆ getONNXSession()

const cms::Ort::ONNXRuntime& BaseMVACache::getONNXSession ( ) const
inline

Definition at line 76 of file BaseMVAValueMapProducer.h.

References ort_.

76 { return *ort_; }
std::unique_ptr< cms::Ort::ONNXRuntime > ort_

◆ getTFSession()

tensorflow::Session* BaseMVACache::getTFSession ( ) const
inline

Definition at line 75 of file BaseMVAValueMapProducer.h.

References tf_session_.

75 { return tf_session_; }
tensorflow::Session * tf_session_

Member Data Documentation

◆ graph_

std::shared_ptr<tensorflow::GraphDef> BaseMVACache::graph_
private

Definition at line 79 of file BaseMVAValueMapProducer.h.

Referenced by BaseMVACache().

◆ ort_

std::unique_ptr<cms::Ort::ONNXRuntime> BaseMVACache::ort_
private

Definition at line 81 of file BaseMVAValueMapProducer.h.

Referenced by BaseMVACache(), and getONNXSession().

◆ tf_session_

tensorflow::Session* BaseMVACache::tf_session_ = nullptr
private

Definition at line 80 of file BaseMVAValueMapProducer.h.

Referenced by BaseMVACache(), getTFSession(), and ~BaseMVACache().