CMS 3D CMS Logo

Classes | Typedefs | Functions | Variables
tensorflow Namespace Reference

Classes

class  NTSession
 
class  NTSessionFactory
 
class  NTSessionRegistrar
 
class  TBBSession
 
class  TBBSessionFactory
 
class  TBBSessionRegistrar
 

Typedefs

typedef std::pair< std::string, Tensor > NamedTensor
 
typedef std::vector< NamedTensor > NamedTensorList
 

Functions

bool closeSession (Session *&session)
 
Session * createSession (SessionOptions &sessionOptions)
 
Session * createSession (int nThreads=1)
 
Session * createSession (MetaGraphDef *metaGraph, const std::string &exportDir, SessionOptions &sessionOptions)
 
Session * createSession (MetaGraphDef *metaGraph, const std::string &exportDir, int nThreads=1)
 
Session * createSession (GraphDef *graphDef, SessionOptions &sessionOptions)
 
Session * createSession (GraphDef *graphDef, int nThreads=1)
 
GraphDef * loadGraphDef (const std::string &pbFile)
 
MetaGraphDef * loadMetaGraph (const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
 
MetaGraphDef * loadMetaGraph (const std::string &exportDir, const std::string &tag=kSavedModelTagServe, int nThreads=1)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, const std::vector< std::string > &targetNodes, std::vector< Tensor > *outputs)
 
void run (Session *session, const std::vector< std::string > &inputNames, const std::vector< Tensor > &inputTensors, const std::vector< std::string > &outputNames, const std::vector< std::string > &targetNodes, std::vector< Tensor > *outputs)
 
void run (Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs)
 
void run (Session *session, const std::vector< std::string > &inputNames, const std::vector< Tensor > &inputTensors, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs)
 
void setLogging (const std::string &level="3")
 
void setThreading (SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
 

Variables

static NTSessionRegistrar registrar
 
static TBBSessionRegistrar registrar
 

Typedef Documentation

typedef std::pair<std::string, Tensor> tensorflow::NamedTensor

Definition at line 25 of file TensorFlow.h.

typedef std::vector<NamedTensor> tensorflow::NamedTensorList

Definition at line 26 of file TensorFlow.h.

Function Documentation

bool tensorflow::closeSession ( Session *&  session)

Definition at line 193 of file TensorFlow.cc.

References dataDML::session, btagGenBb_cfi::Status, and mps_update::status.

Referenced by DTOccupancyTestML::dqmEndLuminosityBlock(), DeepDoubleXTFJetTagsProducer::~DeepDoubleXTFJetTagsProducer(), DeepFlavourTFJetTagsProducer::~DeepFlavourTFJetTagsProducer(), and deep_tau::DeepTauCache::~DeepTauCache().

194 {
195  if (session == nullptr)
196  {
197  return true;
198  }
199 
200  // close and delete the session
201  Status status = session->Close();
202  delete session;
203 
204  // reset the pointer
205  session = nullptr;
206 
207  return status.ok();
208 }
Session * tensorflow::createSession ( SessionOptions &  sessionOptions)

Definition at line 87 of file TensorFlow.cc.

References Exception, dataDML::session, btagGenBb_cfi::Status, and mps_update::status.

Referenced by BaseMVAValueMapProducer< pat::Jet >::BaseMVAValueMapProducer(), createSession(), DeepDoubleXTFJetTagsProducer::DeepDoubleXTFJetTagsProducer(), DeepFlavourTFJetTagsProducer::DeepFlavourTFJetTagsProducer(), deep_tau::DeepTauCache::DeepTauCache(), and DTOccupancyTestML::dqmEndLuminosityBlock().

88 {
89  // objects to create the session
90  Status status;
91 
92  // create a new, empty session
93  Session* session = nullptr;
94  status = NewSession(sessionOptions, &session);
95  if (!status.ok())
96  {
97  throw cms::Exception("InvalidSession")
98  << "error while creating session: " << status.ToString();
99  }
100 
101  return session;
102 }
Session * tensorflow::createSession ( int  nThreads = 1)

Definition at line 104 of file TensorFlow.cc.

References createSession(), and setThreading().

105 {
106  // create session options and set thread options
107  SessionOptions sessionOptions;
108  setThreading(sessionOptions, nThreads);
109 
110  return createSession(sessionOptions);
111 }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:87
void setThreading(SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
Definition: TensorFlow.cc:19
Session * tensorflow::createSession ( MetaGraphDef *  metaGraph,
const std::string &  exportDir,
SessionOptions &  sessionOptions 
)

Definition at line 113 of file TensorFlow.cc.

References createSession(), Default, Exception, convertSQLiteXML::ok, dataDML::session, btagGenBb_cfi::Status, mps_update::status, and AlCaHLTBitMon_QueryRunRegistry::string.

115 {
116  Session* session = createSession(sessionOptions);
117 
118  // add the graph def from the meta graph
119  Status status;
120  status = session->Create(metaGraph->graph_def());
121  if (!status.ok())
122  {
123  throw cms::Exception("InvalidSession")
124  << "error while attaching meta graph to session: " << status.ToString();
125  }
126 
127  // restore variables using the variable and index files in the export directory
128  // first, find names and paths
129  std::string varFileTensorName = metaGraph->saver_def().filename_tensor_name();
130  std::string restoreOpName = metaGraph->saver_def().restore_op_name();
131  std::string varDir = io::JoinPath(exportDir, kSavedModelVariablesDirectory);
132  std::string indexFile = io::JoinPath(varDir, MetaFilename(kSavedModelVariablesFilename));
133  std::string varFile = io::JoinPath(varDir, kSavedModelVariablesFilename);
134 
135  // when the index file is missing, there's nothing to do
136  if (!Env::Default()->FileExists(indexFile).ok())
137  {
138  return session;
139  }
140 
141  // create a tensor to store the variable file
142  Tensor varFileTensor(DT_STRING, TensorShape({}));
143  varFileTensor.scalar<std::string>()() = varFile;
144 
145  // run the restore op
146  status = session->Run({ { varFileTensorName, varFileTensor } }, {}, { restoreOpName }, nullptr);
147  if (!status.ok())
148  {
149  throw cms::Exception("InvalidSession")
150  << "error while restoring variables in session: " << status.ToString();
151  }
152 
153  return session;
154 }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:87
#define Default
Definition: vmac.h:110
Session * tensorflow::createSession ( MetaGraphDef *  metaGraph,
const std::string &  exportDir,
int  nThreads = 1 
)

Definition at line 156 of file TensorFlow.cc.

References createSession(), and setThreading().

157 {
158  // create session options and set thread options
159  SessionOptions sessionOptions;
160  setThreading(sessionOptions, nThreads);
161 
162  return createSession(metaGraph, exportDir, sessionOptions);
163 }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:87
void setThreading(SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
Definition: TensorFlow.cc:19
Session * tensorflow::createSession ( GraphDef *  graphDef,
SessionOptions &  sessionOptions 
)

Definition at line 165 of file TensorFlow.cc.

References createSession(), Exception, dataDML::session, btagGenBb_cfi::Status, and mps_update::status.

166 {
167  // create a new, empty session
168  Session* session = createSession(sessionOptions);
169 
170  // add the graph def
171  Status status;
172  status = session->Create(*graphDef);
173 
174  // check for success
175  if (!status.ok())
176  {
177  throw cms::Exception("InvalidSession")
178  << "error while attaching graph def to session: " << status.ToString();
179  }
180 
181  return session;
182 }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:87
Session * tensorflow::createSession ( GraphDef *  graphDef,
int  nThreads = 1 
)

Definition at line 184 of file TensorFlow.cc.

References createSession(), and setThreading().

185 {
186  // create session options and set thread options
187  SessionOptions sessionOptions;
188  setThreading(sessionOptions, nThreads);
189 
190  return createSession(graphDef, sessionOptions);
191 }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:87
void setThreading(SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
Definition: TensorFlow.cc:19
GraphDef * tensorflow::loadGraphDef ( const std::string &  pbFile)

Definition at line 68 of file TensorFlow.cc.

References Default, Exception, btagGenBb_cfi::Status, and mps_update::status.

Referenced by BaseMVAValueMapProducer< pat::Jet >::BaseMVAValueMapProducer(), deep_tau::DeepTauCache::DeepTauCache(), DTOccupancyTestML::dqmEndLuminosityBlock(), DeepDoubleXTFJetTagsProducer::initializeGlobalCache(), and DeepFlavourTFJetTagsProducer::initializeGlobalCache().

69 {
70  // objects to load the graph
71  Status status;
72 
73  // load it
74  GraphDef* graphDef = new GraphDef();
75  status = ReadBinaryProto(Env::Default(), pbFile, graphDef);
76 
77  // check for success
78  if (!status.ok())
79  {
80  throw cms::Exception("InvalidGraphDef")
81  << "error while loading graph def: " << status.ToString();
82  }
83 
84  return graphDef;
85 }
#define Default
Definition: vmac.h:110
MetaGraphDef * tensorflow::loadMetaGraph ( const std::string &  exportDir,
const std::string &  tag,
SessionOptions &  sessionOptions 
)

Definition at line 39 of file TensorFlow.cc.

References Exception, btagGenBb_cfi::Status, and mps_update::status.

Referenced by loadMetaGraph().

41 {
42  // objects to load the graph
43  Status status;
44  RunOptions runOptions;
45  SavedModelBundle bundle;
46 
47  // load the model
48  status = LoadSavedModel(sessionOptions, runOptions, exportDir, { tag }, &bundle);
49  if (!status.ok())
50  {
51  throw cms::Exception("InvalidMetaGraph")
52  << "error while loading meta graph: " << status.ToString();
53  }
54 
55  // return a copy of the graph
56  return new MetaGraphDef(bundle.meta_graph_def);
57 }
MetaGraphDef * tensorflow::loadMetaGraph ( const std::string &  exportDir,
const std::string &  tag = kSavedModelTagServe,
int  nThreads = 1 
)

Definition at line 59 of file TensorFlow.cc.

References loadMetaGraph(), and setThreading().

60 {
61  // create session options and set thread options
62  SessionOptions sessionOptions;
63  setThreading(sessionOptions, nThreads);
64 
65  return loadMetaGraph(exportDir, tag, sessionOptions);
66 }
MetaGraphDef * loadMetaGraph(const std::string &exportDir, const std::string &tag, SessionOptions &sessionOptions)
Definition: TensorFlow.cc:39
void setThreading(SessionOptions &sessionOptions, int nThreads, const std::string &singleThreadPool="no_threads")
Definition: TensorFlow.cc:19
void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
const std::vector< std::string > &  targetNodes,
std::vector< Tensor > *  outputs 
)

Definition at line 210 of file TensorFlow.cc.

References Exception, btagGenBb_cfi::Status, and mps_update::status.

Referenced by DPFIsolation::getPredictions(), DeepTauId::getPredictions(), DeepDoubleXTFJetTagsProducer::produce(), DeepFlavourTFJetTagsProducer::produce(), BaseMVAValueMapProducer< T >::produce(), run(), and DTOccupancyTestML::runOccupancyTest().

213 {
214  if (session == nullptr)
215  {
216  throw cms::Exception("InvalidSession") << "cannot run empty session";
217  }
218 
219  // run and check the status
220  Status status = session->Run(inputs, outputNames, targetNodes, outputs);
221  if (!status.ok())
222  {
223  throw cms::Exception("InvalidRun")
224  << "error while running session: " << status.ToString();
225  }
226 }
void tensorflow::run ( Session *  session,
const std::vector< std::string > &  inputNames,
const std::vector< Tensor > &  inputTensors,
const std::vector< std::string > &  outputNames,
const std::vector< std::string > &  targetNodes,
std::vector< Tensor > *  outputs 
)

Definition at line 228 of file TensorFlow.cc.

References Exception, mps_fire::i, PatBasicFWLiteJetAnalyzer_Selector_cfg::inputs, and run().

231 {
232  if (inputNames.size() != inputTensors.size())
233  {
234  throw cms::Exception("InvalidInput") << "numbers of input names and tensors not equal";
235  }
236 
237  NamedTensorList inputs;
238  for (size_t i = 0; i < inputNames.size(); i++)
239  {
240  inputs.push_back(NamedTensor(inputNames[i], inputTensors[i]));
241  }
242 
243  run(session, inputs, outputNames, targetNodes, outputs);
244 }
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:26
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:25
void tensorflow::run ( Session *  session,
const NamedTensorList &  inputs,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs 
)
void tensorflow::run ( Session *  session,
const std::vector< std::string > &  inputNames,
const std::vector< Tensor > &  inputTensors,
const std::vector< std::string > &  outputNames,
std::vector< Tensor > *  outputs 
)

Definition at line 252 of file TensorFlow.cc.

References PatBasicFWLiteJetAnalyzer_Selector_cfg::outputs, and run().

void tensorflow::setLogging ( const std::string &  level = "3")
void tensorflow::setThreading ( SessionOptions &  sessionOptions,
int  nThreads,
const std::string &  singleThreadPool = "no_threads" 
)

Definition at line 19 of file TensorFlow.cc.

References Exception, and jets_cff::singleThreadPool.

Referenced by BaseMVAValueMapProducer< pat::Jet >::BaseMVAValueMapProducer(), createSession(), DeepDoubleXTFJetTagsProducer::DeepDoubleXTFJetTagsProducer(), DeepFlavourTFJetTagsProducer::DeepFlavourTFJetTagsProducer(), deep_tau::DeepTauCache::DeepTauCache(), and loadMetaGraph().

21 {
22  // set number of threads used for intra and inter operation communication
23  sessionOptions.config.set_intra_op_parallelism_threads(nThreads);
24  sessionOptions.config.set_inter_op_parallelism_threads(nThreads);
25 
26  // when exactly one thread is requested use a custom thread pool
27  if (nThreads == 1 && !singleThreadPool.empty())
28  {
29  // check for known thread pools
30  if (singleThreadPool != "no_threads" && singleThreadPool != "tbb")
31  {
32  throw cms::Exception("UnknownThreadPool")
33  << "thread pool '" << singleThreadPool << "' unknown, use 'no_threads' or 'tbb'";
34  }
35  sessionOptions.target = singleThreadPool;
36  }
37 }
singleThreadPool
Definition: jets_cff.py:298

Variable Documentation

NTSessionRegistrar tensorflow::registrar
static

Definition at line 172 of file NTSession.cc.

TBBSessionRegistrar tensorflow::registrar
static

Definition at line 173 of file TBBSession.cc.