CMS 3D CMS Logo

ProcTMVA.cc
Go to the documentation of this file.
1 #include <unistd.h>
2 #include <algorithm>
3 #include <iostream>
4 #include <sstream>
5 #include <fstream>
6 #include <cstddef>
7 #include <cstring>
8 #include <cstdio>
9 #include <vector>
10 #include <memory>
11 
12 #include <xercesc/dom/DOM.hpp>
13 
14 #include <TDirectory.h>
15 #include <TTree.h>
16 #include <TFile.h>
17 #include <TCut.h>
18 
19 #include <TMVA/Types.h>
20 #include <TMVA/Factory.h>
21 #include <TMVA/DataLoader.h>
22 
24 
28 
34 
36 
37 using namespace PhysicsTools;
38 
39 namespace { // anonymous
40 
41 class ROOTContextSentinel {
42  public:
43  ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
44  ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
45 
46  private:
47  TDirectory *dir;
48  TFile *file;
49 };
50 
51 class ProcTMVA : public TrainProcessor {
52  public:
54 
55  ProcTMVA(const char *name, const AtomicId *id,
56  MVATrainer *trainer);
57  ~ProcTMVA() override;
58 
59  void configure(DOMElement *elem) override;
60  Calibration::VarProcessor *getCalibration() const override;
61 
62  void trainBegin() override;
63  void trainData(const std::vector<double> *values,
64  bool target, double weight) override;
65  void trainEnd() override;
66 
67  bool load() override;
68  void cleanup() override;
69 
70  private:
71  void runTMVATrainer();
72 
73  struct Method {
74  TMVA::Types::EMVA type;
77  };
78 
79  std::string getTreeName() const
80  { return trainer->getName() + '_' + (const char*)getName(); }
81 
82  std::string getWeightsFile(const Method &meth, const char *ext) const
83  {
84  return "weights/" + getTreeName() + '_' +
85  meth.name + ".weights." + ext;
86  }
87 
88  enum Iteration {
89  ITER_EXPORT,
90  ITER_DONE
91  } iteration;
92 
93  std::vector<Method> methods;
94  std::vector<std::string> names;
95  std::auto_ptr<TFile> file;
96  TTree *treeSig, *treeBkg;
97  Double_t weight;
98  std::vector<Double_t> vars;
99  bool needCleanup;
100  unsigned long nSignal;
101  unsigned long nBackground;
102  bool doUserTreeSetup;
103  std::string setupCuts; // cut applied by TMVA to signal and background trees
104  std::string setupOptions; // training/test tree TMVA setup options
105 };
106 
107 ProcTMVA::Registry registry("ProcTMVA");
108 
109 ProcTMVA::ProcTMVA(const char *name, const AtomicId *id,
110  MVATrainer *trainer) :
111  TrainProcessor(name, id, trainer),
112  iteration(ITER_EXPORT), treeSig(nullptr), treeBkg(nullptr), needCleanup(false),
113  doUserTreeSetup(false), setupOptions("SplitMode = Block:!V")
114 {
115 }
116 
117 ProcTMVA::~ProcTMVA()
118 {
119 }
120 
121 void ProcTMVA::configure(DOMElement *elem)
122 {
123  std::vector<SourceVariable*> inputs = getInputs().get();
124 
125  for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
126  iter != inputs.end(); iter++) {
127  std::string name = (const char*)(*iter)->getName();
128 
129  if (std::find(names.begin(), names.end(), name)
130  != names.end()) {
131  for(unsigned i = 1;; i++) {
132  std::ostringstream ss;
133  ss << name << "_" << i;
134  if (std::find(names.begin(), names.end(),
135  ss.str()) == names.end()) {
136  name = ss.str();
137  break;
138  }
139  }
140  }
141 
142  names.push_back(name);
143  }
144 
145  for(DOMNode *node = elem->getFirstChild();
146  node; node = node->getNextSibling()) {
147  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
148  continue;
149 
150  bool isMethod = !std::strcmp(XMLSimpleStr(node->getNodeName()), "method");
151  bool isSetup = !std::strcmp(XMLSimpleStr(node->getNodeName()), "setup");
152 
153  if (!isMethod && !isSetup)
154  throw cms::Exception("ProcTMVA")
155  << "Expected method or setup tag in config section."
156  << std::endl;
157 
158  elem = static_cast<DOMElement*>(node);
159 
160  if (isMethod) {
161  Method method;
162  method.type = TMVA::Types::Instance().GetMethodType(
163  XMLDocument::readAttribute<std::string>(
164  elem, "type").c_str());
165 
166  method.name =
167  XMLDocument::readAttribute<std::string>(
168  elem, "name");
169 
170  method.description =
171  (const char*)XMLSimpleStr(node->getTextContent());
172 
173  methods.push_back(method);
174  } else if (isSetup) {
175  if (doUserTreeSetup)
176  throw cms::Exception("ProcTMVA")
177  << "Multiple appeareances of setup "
178  "tag in config section."
179  << std::endl;
180 
181  doUserTreeSetup = true;
182 
183  setupCuts =
184  XMLDocument::readAttribute<std::string>(
185  elem, "cuts");
186  setupOptions =
187  XMLDocument::readAttribute<std::string>(
188  elem, "options");
189  }
190  }
191 
192  if (methods.empty())
193  throw cms::Exception("ProcTMVA")
194  << "Expected TMVA method in config section."
195  << std::endl;
196 }
197 
198 bool ProcTMVA::load()
199 {
200  bool ok = true;
201  for(std::vector<Method>::const_iterator iter = methods.begin();
202  iter != methods.end(); ++iter) {
203  std::ifstream in(getWeightsFile(*iter, "xml").c_str());
204  if (!in.good()) {
205  ok = false;
206  break;
207  }
208  }
209 
210  if (!ok)
211  return false;
212 
213  iteration = ITER_DONE;
214  trained = true;
215  return true;
216 }
217 
// Returns the number of bytes between the current read position of `in` and
// the end of the stream. The read position is restored before returning.
std::size_t getStreamSize(std::ifstream &in)
{
	const std::ifstream::pos_type start = in.tellg();
	in.seekg(0, std::ios::end);
	const std::ifstream::pos_type stop = in.tellg();
	in.seekg(start, std::ios::beg);

	const std::streamoff remaining = stop - start;
	return static_cast<std::size_t>(remaining);
}
227 
228 Calibration::VarProcessor *ProcTMVA::getCalibration() const
229 {
231 
232  std::ifstream in(getWeightsFile(methods[0], "xml").c_str(),
233  std::ios::binary | std::ios::in);
234  if (!in.good())
235  throw cms::Exception("ProcTMVA")
236  << "Weights file " << getWeightsFile(methods[0], "xml")
237  << " cannot be opened for reading." << std::endl;
238 
239  std::size_t size = getStreamSize(in) + methods[0].name.size();
240  for(std::vector<std::string>::const_iterator iter = names.begin();
241  iter != names.end(); ++iter)
242  size += iter->size() + 1;
243  size += (size / 32) + 128;
244 
245  std::shared_ptr<char> buffer( new char[size] );
246  ext::omemstream os(buffer.get(), size);
247  /* call dtor of ozs at end */ {
248  ext::ozstream ozs(&os);
249  ozs << methods[0].name << "\n";
250  ozs << names.size() << "\n";
251  for(std::vector<std::string>::const_iterator iter =
252  names.begin();
253  iter != names.end(); ++iter)
254  ozs << *iter << "\n";
255  ozs << in.rdbuf();
256  ozs.flush();
257  }
258  size = os.end() - os.begin();
259  calib->store.resize(size);
260  std::memcpy(&calib->store.front(), os.begin(), size);
261 
262  in.close();
263 
264  calib->method = "ProcTMVA";
265 
266  return calib;
267 }
268 
269 void ProcTMVA::trainBegin()
270 {
271  if (iteration == ITER_EXPORT) {
272  ROOTContextSentinel ctx;
273 
274  file = std::auto_ptr<TFile>(TFile::Open(
275  trainer->trainFileName(this, "root",
276  "input").c_str(),
277  "RECREATE"));
278  if (!file.get())
279  throw cms::Exception("ProcTMVA")
280  << "Could not open ROOT file for writing."
281  << std::endl;
282 
283  file->cd();
284  treeSig = new TTree((getTreeName() + "_sig").c_str(),
285  "MVATrainer signal");
286  treeBkg = new TTree((getTreeName() + "_bkg").c_str(),
287  "MVATrainer background");
288 
289  treeSig->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
290  treeBkg->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
291 
292  vars.resize(names.size());
293 
294  std::vector<Double_t>::iterator pos = vars.begin();
295  for(std::vector<std::string>::const_iterator iter =
296  names.begin(); iter != names.end(); iter++, pos++) {
297  treeSig->Branch(iter->c_str(), &*pos,
298  (*iter + "/D").c_str());
299  treeBkg->Branch(iter->c_str(), &*pos,
300  (*iter + "/D").c_str());
301  }
302 
303  nSignal = nBackground = 0;
304  }
305 }
306 
307 void ProcTMVA::trainData(const std::vector<double> *values,
308  bool target, double weight)
309 {
310  if (iteration != ITER_EXPORT)
311  return;
312 
313  this->weight = weight;
314  for(unsigned int i = 0; i < vars.size(); i++, values++)
315  vars[i] = values->front();
316 
317  if (target) {
318  treeSig->Fill();
319  nSignal++;
320  } else {
321  treeBkg->Fill();
322  nBackground++;
323  }
324 }
325 
326 void ProcTMVA::runTMVATrainer()
327 {
328  needCleanup = true;
329 
330  if (nSignal < 1 || nBackground < 1)
331  throw cms::Exception("ProcTMVA")
332  << "Not going to run TMVA: "
333  "No signal (" << nSignal << ") or background ("
334  << nBackground << ") events!" << std::endl;
335 
336  std::auto_ptr<TFile> file(TFile::Open(
337  trainer->trainFileName(this, "root", "output").c_str(),
338  "RECREATE"));
339  if (!file.get())
340  throw cms::Exception("ProcTMVA")
341  << "Could not open TMVA ROOT file for writing."
342  << std::endl;
343 
344  std::unique_ptr<TMVA::Factory> factory(
345  new TMVA::Factory(getTreeName().c_str(), file.get(), ""));
346 
347  std::unique_ptr<TMVA::DataLoader> loader(new TMVA::DataLoader("ProcTMVA"));
348 
349  loader->SetInputTrees(treeSig, treeBkg);
350 
351  for(std::vector<std::string>::const_iterator iter = names.begin();
352  iter != names.end(); iter++)
353  loader->AddVariable(iter->c_str(), 'D');
354 
355  loader->SetWeightExpression("__WEIGHT__");
356 
357  if (doUserTreeSetup) {
358  loader->PrepareTrainingAndTestTree(
359  setupCuts.c_str(), setupOptions);
360  } else {
361  loader->PrepareTrainingAndTestTree(
362  "", 0, 0, 0, 0,
363  "SplitMode=Block:!V");
364  }
365 
366  for(std::vector<Method>::const_iterator iter = methods.begin();
367  iter != methods.end(); ++iter)
368  factory->BookMethod(loader.get(), iter->type, iter->name, iter->description);
369 
370  factory->TrainAllMethods();
371  factory->TestAllMethods();
372  factory->EvaluateAllMethods();
373 
374  factory.release(); // ROOT seems to take care of destruction?!
375  loader.release();
376 
377  file->Close();
378 
379  printf("TMVA training factory completed\n");
380 }
381 
382 void ProcTMVA::trainEnd()
383 {
384  switch(iteration) {
385  case ITER_EXPORT:
386  /* ROOT context-safe */ {
387  ROOTContextSentinel ctx;
388  file->cd();
389  treeSig->Write();
390  treeBkg->Write();
391 
392  file->Close();
393  file.reset();
394  file = std::auto_ptr<TFile>(TFile::Open(
395  trainer->trainFileName(this, "root",
396  "input").c_str()));
397  if (!file.get())
398  throw cms::Exception("ProcTMVA")
399  << "Could not open ROOT file for "
400  "reading." << std::endl;
401  treeSig = dynamic_cast<TTree*>(
402  file->Get((getTreeName() + "_sig").c_str()));
403  treeBkg = dynamic_cast<TTree*>(
404  file->Get((getTreeName() + "_bkg").c_str()));
405 
406  runTMVATrainer();
407 
408  file->Close();
409  treeSig = nullptr;
410  treeBkg = nullptr;
411  file.reset();
412  }
413  vars.clear();
414 
415  iteration = ITER_DONE;
416  trained = true;
417  break;
418  default:
419  /* shut up */;
420  }
421 }
422 
423 void ProcTMVA::cleanup()
424 {
425  if (!needCleanup)
426  return;
427 
428  std::remove(trainer->trainFileName(this, "root", "input").c_str());
429  std::remove(trainer->trainFileName(this, "root", "output").c_str());
430  for(std::vector<Method>::const_iterator iter = methods.begin();
431  iter != methods.end(); ++iter) {
432  std::remove(getWeightsFile(*iter, "xml").c_str());
433  std::remove(getWeightsFile(*iter, "root").c_str());
434  }
435  rmdir("weights");
436 }
437 
438 } // anonymous namespace
439 
// Plugin definition for ProcTMVA. The macro comes from the MVATrainer
// framework headers; presumably it makes the processor available to the
// plugin factory. Confirm against the macro definition.
440 MVA_TRAINER_DEFINE_PLUGIN(ProcTMVA);
size
Write out results.
type
Definition: HCALResponse.h:21
static void cleanup(const Factory::MakerMap::value_type &v)
Definition: Factory.cc:12
Definition: weight.py:1
template to generate a registry singleton for a type.
#define nullptr
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:20
const std::string names[nVars_]
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:31
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
#define end
Definition: vmac.h:39
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:18
def load(fileName)
Definition: svgfig.py:546
def remove(d, key, TELL=False)
Definition: MatrixUtil.py:211
#define begin
Definition: vmac.h:32
dbl *** dir
Definition: mlp_gen.cc:35
vars
Definition: DeepTauId.cc:77
static Interceptor::Registry registry("Interceptor")
Definition: memstream.h:15
std::vector< unsigned char > store
Definition: MVAComputer.h:200
#define MVA_TRAINER_DEFINE_PLUGIN(T)