BaseMVACache Class Reference

#include <BaseMVAValueMapProducer.h>

Public Member Functions

 BaseMVACache (const std::string &model_path, const std::string &backend)
 
const cms::Ort::ONNXRuntime & getONNXSession () const
 
tensorflow::Session * getTFSession () const
 
 ~BaseMVACache ()
 

Private Attributes

std::shared_ptr< tensorflow::GraphDef > graph_
 
std::unique_ptr< cms::Ort::ONNXRuntime > ort_
 
tensorflow::Session * tf_session_ = nullptr
 

Detailed Description

Definition at line 56 of file BaseMVAValueMapProducer.h.

Constructor & Destructor Documentation

BaseMVACache::BaseMVACache (const std::string & model_path, const std::string & backend)
inline

Definition at line 58 of file BaseMVAValueMapProducer.h.

References tensorflow::createSession(), graph_, tensorflow::loadGraphDef(), ort_, and tf_session_.

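    {
      if (backend == "TF") {
        // TensorFlow backend: load the graph, then open a session on it.
        graph_.reset(tensorflow::loadGraphDef(model_path));
        tf_session_ = tensorflow::createSession(graph_.get());
      } else if (backend == "ONNX") {
        // ONNX backend: the runtime wrapper loads the model itself.
        ort_ = std::make_unique<cms::Ort::ONNXRuntime>(model_path);
      }
    }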

Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68

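
The constructor listing above selects a backend by exact string comparison and has no fallback branch. The following is a minimal sketch of that dispatch pattern, not the CMSSW implementation: `FakeTFSession`, `FakeOrtRuntime`, `CacheSketch`, `hasTF()` and `hasONNX()` are hypothetical names introduced here so the example compiles without TensorFlow or ONNX Runtime.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-ins; the real types are tensorflow::Session and
// cms::Ort::ONNXRuntime, which need the CMSSW environment to build.
struct FakeTFSession {
  std::string model;
};
struct FakeOrtRuntime {
  explicit FakeOrtRuntime(std::string path) : model(std::move(path)) {}
  std::string model;
};

class CacheSketch {
public:
  CacheSketch(const std::string& model_path, const std::string& backend) {
    if (backend == "TF") {
      tf_ = std::make_unique<FakeTFSession>();
      tf_->model = model_path;
    } else if (backend == "ONNX") {
      ort_ = std::make_unique<FakeOrtRuntime>(model_path);
    }
    // No else branch: an unrecognized backend string leaves both members empty.
  }

  // Helpers added for this sketch only; the documented class has no such queries.
  bool hasTF() const { return tf_ != nullptr; }
  bool hasONNX() const { return ort_ != nullptr; }

private:
  std::unique_ptr<FakeTFSession> tf_;
  std::unique_ptr<FakeOrtRuntime> ort_;
};
```

In this sketch, `CacheSketch("model.pb", "TF")` populates only the TensorFlow member, while a lowercase `"tf"` populates neither, because the comparison is case-sensitive.
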
BaseMVACache::~BaseMVACache ( )
inline

Definition at line 66 of file BaseMVAValueMapProducer.h.

References tensorflow::closeSession(), and tf_session_.

bool closeSession(Session *&session)
Definition: TensorFlow.cc:198
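
The destructor hands the raw `tf_session_` pointer to `closeSession`, which takes it by reference. A rough sketch of that ownership pattern follows. Everything in the `sketch` namespace is a hypothetical stand-in. In particular, it is an assumption here that closing releases the session, resets the caller's pointer, and tolerates a null pointer; the real behavior is defined in TensorFlow.cc.

```cpp
#include <cassert>

namespace sketch {

// Hypothetical stand-in for tensorflow::Session.
struct Session {};

// Counter used only to make the example observable.
int live_sessions = 0;

Session* createSession() {
  ++live_sessions;
  return new Session();
}

// Same shape as the documented bool closeSession(Session *&session).
// Assumed semantics: free the session and null the caller's pointer.
bool closeSession(Session*& session) {
  if (session == nullptr) {
    return true;  // nothing to close
  }
  delete session;
  session = nullptr;
  --live_sessions;
  return true;
}

class SessionOwner {
public:
  explicit SessionOwner(bool open) {
    if (open) {
      session_ = createSession();
    }
  }
  ~SessionOwner() { closeSession(session_); }

  // The raw pointer must have exactly one owner in this sketch.
  SessionOwner(const SessionOwner&) = delete;
  SessionOwner& operator=(const SessionOwner&) = delete;

  Session* get() const { return session_; }

private:
  Session* session_ = nullptr;
};

}  // namespace sketch
```

Because the member defaults to `nullptr`, an owner that never opened a session can still be destroyed safely under the assumed null-tolerant close.
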

Member Function Documentation

const cms::Ort::ONNXRuntime& BaseMVACache::getONNXSession ( ) const
inline

Definition at line 69 of file BaseMVAValueMapProducer.h.

References ort_.

    { return *ort_; }

tensorflow::Session* BaseMVACache::getTFSession ( ) const
inline

Definition at line 68 of file BaseMVAValueMapProducer.h.

References tf_session_.

    { return tf_session_; }
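
The two accessors above differ in how they behave when their backend was not selected: `getTFSession()` returns a pointer that can be null, whereas `getONNXSession()` dereferences `ort_` unconditionally. The sketch below illustrates the difference and one possible checked alternative. `Runtime`, `Holder`, `getChecked()` and `tryGet()` are hypothetical names for illustration and are not part of `BaseMVACache`.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

// Hypothetical stand-in for cms::Ort::ONNXRuntime.
struct Runtime {
  int id = 7;
};

class Holder {
public:
  explicit Holder(bool with_runtime) {
    if (with_runtime) {
      ort_ = std::make_unique<Runtime>();
    }
  }

  // Unchecked, same shape as the documented accessor:
  // only valid to call when ort_ has been set.
  const Runtime& get() const { return *ort_; }

  // Checked variant (an addition for this example): fails loudly instead.
  const Runtime& getChecked() const {
    if (!ort_) {
      throw std::logic_error("ONNX runtime requested but not configured");
    }
    return *ort_;
  }

  // Pointer-returning variant: absence is reported as nullptr.
  const Runtime* tryGet() const { return ort_.get(); }

private:
  std::unique_ptr<Runtime> ort_;
};
```

A caller holding a cache built with one backend would, under this sketch, either test `tryGet()` against `nullptr` or rely on `getChecked()` throwing, rather than calling the unchecked accessor for the other backend.
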

Member Data Documentation

std::shared_ptr<tensorflow::GraphDef> BaseMVACache::graph_
private

Definition at line 72 of file BaseMVAValueMapProducer.h.

Referenced by BaseMVACache().

std::unique_ptr<cms::Ort::ONNXRuntime> BaseMVACache::ort_
private

Definition at line 74 of file BaseMVAValueMapProducer.h.

Referenced by BaseMVACache(), and getONNXSession().

tensorflow::Session* BaseMVACache::tf_session_ = nullptr
private

Definition at line 73 of file BaseMVAValueMapProducer.h.

Referenced by BaseMVACache(), getTFSession(), and ~BaseMVACache().