CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
GBRForestWriter.cc
Go to the documentation of this file.
2 
4 
7 
9 
#include "TMVA/ClassifierFactory.h"
#include "TMVA/Event.h"
#include "TMVA/Factory.h"
#include "TMVA/MethodBase.h"
#include "TMVA/MethodBDT.h"
#include "TMVA/Reader.h"
#include "TMVA/Tools.h"

#include <TFile.h>

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
19 
21  : moduleLabel_(cfg.getParameter<std::string>("@module_label"))
22 {
23  edm::VParameterSet cfgJobs = cfg.getParameter<edm::VParameterSet>("jobs");
24  for ( edm::VParameterSet::const_iterator cfgJob = cfgJobs.begin();
25  cfgJob != cfgJobs.end(); ++cfgJob ) {
26  jobEntryType* job = new jobEntryType(*cfgJob);
27  jobs_.push_back(job);
28  }
29 }
30 
32 {
33  for ( std::vector<jobEntryType*>::iterator it = jobs_.begin();
34  it != jobs_.end(); ++it ) {
35  delete (*it);
36  }
37 }
38 
40 {
41 
42  for ( std::vector<jobEntryType*>::iterator job = jobs_.begin();
43  job != jobs_.end(); ++job ) {
44  std::map<std::string, const GBRForest*> gbrForests; // key = name
45  for ( std::vector<categoryEntryType*>::iterator category = (*job)->categories_.begin();
46  category != (*job)->categories_.end(); ++category ) {
47  const GBRForest* gbrForest = nullptr;
48  if ( (*category)->inputFileType_ == categoryEntryType::kXML ) {
49  TMVA::Tools::Instance();
50  TMVA::Reader* mvaReader = new TMVA::Reader("!V:!Silent");
51  std::vector<Float_t> dummyVariables;
52  for ( vstring::const_iterator inputVariable = (*category)->inputVariables_.begin();
53  inputVariable != (*category)->inputVariables_.end(); ++inputVariable ) {
54  dummyVariables.push_back(0.);
55  mvaReader->AddVariable(inputVariable->data(), &dummyVariables.back());
56  }
57  for ( vstring::const_iterator spectatorVariable = (*category)->spectatorVariables_.begin();
58  spectatorVariable != (*category)->spectatorVariables_.end(); ++spectatorVariable ) {
59  dummyVariables.push_back(0.);
60  mvaReader->AddSpectator(spectatorVariable->data(), &dummyVariables.back());
61  }
62  mvaReader->BookMVA((*category)->gbrForestName_.data(), (*category)->inputFileName_.data());
63  TMVA::MethodBDT* bdt = dynamic_cast<TMVA::MethodBDT*>(mvaReader->FindMVA((*category)->gbrForestName_.data()));
64  if ( !bdt )
65  throw cms::Exception("GBRForestWriter")
66  << "Failed to load MVA = " << (*category)->gbrForestName_.data() << " from file = " << (*category)->inputFileName_ << " !!\n";
67  gbrForest = new GBRForest(bdt);
68  delete mvaReader;
69  TMVA::Tools::DestroyInstance();
70  } else if ( (*category)->inputFileType_ == categoryEntryType::kGBRForest ) {
71  TFile* inputFile = new TFile((*category)->inputFileName_.data());
72  //gbrForest = dynamic_cast<GBRForest*>(inputFile->Get((*category)->gbrForestName_.data())); // CV: dynamic_cast<GBRForest*> fails for some reason ?!
73  gbrForest = (GBRForest*)inputFile->Get((*category)->gbrForestName_.data());
74  delete inputFile;
75  }
76  if ( !gbrForest )
77  throw cms::Exception("GBRForestWriter")
78  << " Failed to load GBRForest = " << (*category)->gbrForestName_.data() << " from file = " << (*category)->inputFileName_ << " !!\n";
79  gbrForests[(*category)->gbrForestName_] = gbrForest;
80  }
81  if ( (*job)->outputFileType_ == jobEntryType::kGBRForest ) {
82  TFile* outputFile = new TFile((*job)->outputFileName_.data(), "RECREATE");
83 
84  for ( std::map<std::string, const GBRForest*>::iterator gbrForest = gbrForests.begin();
85  gbrForest != gbrForests.end(); ++gbrForest ) {
86  outputFile->WriteObject(gbrForest->second, gbrForest->first.data());
87  }
88  delete outputFile;
89  } else if ( (*job)->outputFileType_ == jobEntryType::kSQLLite ) {
91  if ( !dbService.isAvailable() )
92  throw cms::Exception("GBRForestWriter")
93  << " Failed to access PoolDBOutputService !!\n";
94 
95  for ( std::map<std::string, const GBRForest*>::iterator gbrForest = gbrForests.begin();
96  gbrForest != gbrForests.end(); ++gbrForest ) {
97  std::string outputRecord = (*job)->outputRecord_;
98  if ( gbrForests.size() > 1 ) outputRecord.append("_").append(gbrForest->first);
99  dbService->writeOne(gbrForest->second, dbService->beginOfTime(), outputRecord);
100  }
101  }
102 
103  //gbrforest deletion
104  for ( std::map<std::string, const GBRForest*>::iterator gbrForest = gbrForests.begin();
105  gbrForest != gbrForests.end(); ++gbrForest ) {
106  delete gbrForest->second;
107  }
108 
109  }
110 
111 }
112 
114 
T getParameter(std::string const &) const
tuple cfg
Definition: looper.py:259
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:17
std::vector< ParameterSet > VParameterSet
Definition: ParameterSet.h:33
GBRForestWriter(const edm::ParameterSet &)
bool isAvailable() const
Definition: Service.h:46
void writeOne(T *payload, Time_t time, const std::string &recordName, bool withlogging=false)
virtual void analyze(const edm::Event &, const edm::EventSetup &)
std::vector< jobEntryType * > jobs_
moduleLabel_(iConfig.getParameter< string >("@module_label"))