CMS 3D CMS Logo

TfGraphDefProducer.cc
Go to the documentation of this file.
1 // -*- C++ -*-
2 //
3 // Package: PhysicsTools/TensorFlow
4 // Class: TfGraphDefProducer
5 //
9 //
10 // Original Author: Joona Havukainen
11 // Created: Fri, 24 Jul 2020 08:04:00 GMT
12 //
13 //
14 
15 // system include files
16 #include <memory>
17 
18 // user include files
21 
25 
26 // class declaration
27 
29 public:
31  using ReturnType = std::unique_ptr<TfGraphDefWrapper>;
32 
34 
35  static void fillDescriptions(edm::ConfigurationDescriptions& descriptions);
36 
37 private:
39  // ----------member data ---------------------------
40 };
41 
43  : filename_(iConfig.getParameter<edm::FileInPath>("FileName").fullPath()) {
44  auto componentName = iConfig.getParameter<std::string>("ComponentName");
45  setWhatProduced(this, componentName);
46 }
47 
48 // ------------ method called to produce the data ------------
50  auto* graph = tensorflow::loadGraphDef(filename_);
51  return std::make_unique<TfGraphDefWrapper>(tensorflow::createSession(graph), graph);
52 }
53 
56  desc.add<std::string>("ComponentName", "tfGraphDef");
57  desc.add<edm::FileInPath>("FileName");
58  descriptions.add("tfGraphDefProducer", desc);
59 }
60 
61 //define this as a plug-in
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
auto setWhatProduced(T *iThis, const es::Label &iLabel={})
Definition: ESProducer.h:166
T getParameter(std::string const &) const
Definition: ParameterSet.h:303
ReturnType produce(const TfGraphRecord &)
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:120
TfGraphDefProducer(const edm::ParameterSet &)
std::unique_ptr< TfGraphDefWrapper > ReturnType
#define DEFINE_FWK_EVENTSETUP_MODULE(type)
Definition: ModuleFactory.h:61
Session * createSession()
Definition: TensorFlow.cc:137
void add(std::string const &label, ParameterSetDescription const &psetDescription)
HLT enums.
const std::string filename_