CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
ProcTMVA.cc
Go to the documentation of this file.
1 #include <unistd.h>
2 #include <algorithm>
3 #include <iostream>
4 #include <sstream>
5 #include <fstream>
6 #include <cstddef>
7 #include <cstring>
8 #include <cstdio>
9 #include <vector>
10 #include <memory>
11 
12 #include <xercesc/dom/DOM.hpp>
13 
14 #include <TDirectory.h>
15 #include <TTree.h>
16 #include <TFile.h>
17 #include <TCut.h>
18 
19 #include <TMVA/Types.h>
20 #include <TMVA/Factory.h>
21 
23 
27 
33 
34 XERCES_CPP_NAMESPACE_USE
35 
36 using namespace PhysicsTools;
37 
38 namespace { // anonymous
39 
40 class ROOTContextSentinel {
41  public:
42  ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
43  ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
44 
45  private:
46  TDirectory *dir;
47  TFile *file;
48 };
49 
// ProcTMVA: MVATrainer processor that exports its input variables into
// signal/background ROOT trees (trainBegin/trainData), trains the
// configured TMVA methods on them (trainEnd -> runTMVATrainer) and packs
// the first method's XML weights file into the calibration object
// (getCalibration).
50 class ProcTMVA : public TrainProcessor {
51  public:
// NOTE(review): original line 52 is not visible in this view. The
// `registry` definition after this class uses ProcTMVA::Registry, so that
// line presumably declared the Registry typedef -- confirm against the
// full file.
53 
// Lifecycle and configuration hooks (overrides of TrainProcessor).
54  ProcTMVA(const char *name, const AtomicId *id,
55  MVATrainer *trainer);
56  virtual ~ProcTMVA();
57 
58  virtual void configure(DOMElement *elem) override;
59  virtual Calibration::VarProcessor *getCalibration() const override;
60 
61  virtual void trainBegin() override;
62  virtual void trainData(const std::vector<double> *values,
63  bool target, double weight) override;
64  virtual void trainEnd() override;
65 
66  virtual bool load() override;
67  virtual void cleanup() override;
68 
69  private:
// Runs TMVA training/testing/evaluation on treeSig/treeBkg.
70  void runTMVATrainer();
71 
// One <method> entry from the configuration.
72  struct Method {
// TMVA method type, resolved from the "type" attribute.
73  TMVA::Types::EMVA type;
// NOTE(review): original lines 74-75 are not visible in this view. Other
// code reads Method::name and Method::description (configure,
// getWeightsFile, runTMVATrainer), so these lines presumably declared
// those two string members -- confirm against the full file.
76  };
77 
// Tree name prefix: "<trainer name>_<processor name>".
78  std::string getTreeName() const
79  { return trainer->getName() + '_' + (const char*)getName(); }
80 
// Path of a method's weights file:
// "weights/<tree name>_<method name>.weights.<ext>".
81  std::string getWeightsFile(const Method &meth, const char *ext) const
82  {
83  return "weights/" + getTreeName() + '_' +
84  meth.name + ".weights." + ext;
85  }
86 
// ITER_EXPORT: events are being written to the trees;
// ITER_DONE: training finished, or load() found existing weights.
87  enum Iteration {
88  ITER_EXPORT,
89  ITER_DONE
90  } iteration;
91 
// Methods booked from <method> tags, in configuration order.
92  std::vector<Method> methods;
// Unique branch names, one per input variable (built in configure()).
93  std::vector<std::string> names;
// ROOT file holding the exported trees during training.
// NOTE(review): std::auto_ptr is deprecated since C++11 and removed in
// C++17; migrating to std::unique_ptr must be done together with the code
// in trainBegin()/trainEnd() that assigns it.
94  std::auto_ptr<TFile> file;
// Non-owning. Created in trainBegin() after file->cd(), re-read from the
// file in trainEnd(), and reset to 0 afterwards.
95  TTree *treeSig, *treeBkg;
// Branch buffers: trainData() fills these before each TTree::Fill().
96  Double_t weight;
97  std::vector<Double_t> vars;
// Set by runTMVATrainer(); cleanup() only removes files when true.
98  bool needCleanup;
// Number of signal / background events filled during export.
99  unsigned long nSignal;
100  unsigned long nBackground;
// True once configure() has read a <setup> tag.
101  bool doUserTreeSetup;
102  std::string setupCuts; // cut applied by TMVA to signal and background trees
103  std::string setupOptions; // training/test tree TMVA setup options
104 };
105 
106 static ProcTMVA::Registry registry("ProcTMVA");
107 
108 ProcTMVA::ProcTMVA(const char *name, const AtomicId *id,
109  MVATrainer *trainer) :
110  TrainProcessor(name, id, trainer),
111  iteration(ITER_EXPORT), treeSig(0), treeBkg(0), needCleanup(false),
112  doUserTreeSetup(false), setupOptions("SplitMode = Block:!V")
113 {
114 }
115 
116 ProcTMVA::~ProcTMVA()
117 {
118 }
119 
120 void ProcTMVA::configure(DOMElement *elem)
121 {
122  std::vector<SourceVariable*> inputs = getInputs().get();
123 
124  for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
125  iter != inputs.end(); iter++) {
126  std::string name = (const char*)(*iter)->getName();
127 
128  if (std::find(names.begin(), names.end(), name)
129  != names.end()) {
130  for(unsigned i = 1;; i++) {
131  std::ostringstream ss;
132  ss << name << "_" << i;
133  if (std::find(names.begin(), names.end(),
134  ss.str()) == names.end()) {
135  name = ss.str();
136  break;
137  }
138  }
139  }
140 
141  names.push_back(name);
142  }
143 
144  for(DOMNode *node = elem->getFirstChild();
145  node; node = node->getNextSibling()) {
146  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
147  continue;
148 
149  bool isMethod = !std::strcmp(XMLSimpleStr(node->getNodeName()), "method");
150  bool isSetup = !std::strcmp(XMLSimpleStr(node->getNodeName()), "setup");
151 
152  if (!isMethod && !isSetup)
153  throw cms::Exception("ProcTMVA")
154  << "Expected method or setup tag in config section."
155  << std::endl;
156 
157  elem = static_cast<DOMElement*>(node);
158 
159  if (isMethod) {
160  Method method;
161  method.type = TMVA::Types::Instance().GetMethodType(
162  XMLDocument::readAttribute<std::string>(
163  elem, "type").c_str());
164 
165  method.name =
166  XMLDocument::readAttribute<std::string>(
167  elem, "name");
168 
169  method.description =
170  (const char*)XMLSimpleStr(node->getTextContent());
171 
172  methods.push_back(method);
173  } else if (isSetup) {
174  if (doUserTreeSetup)
175  throw cms::Exception("ProcTMVA")
176  << "Multiple appeareances of setup "
177  "tag in config section."
178  << std::endl;
179 
180  doUserTreeSetup = true;
181 
182  setupCuts =
183  XMLDocument::readAttribute<std::string>(
184  elem, "cuts");
185  setupOptions =
186  XMLDocument::readAttribute<std::string>(
187  elem, "options");
188  }
189  }
190 
191  if (!methods.size())
192  throw cms::Exception("ProcTMVA")
193  << "Expected TMVA method in config section."
194  << std::endl;
195 }
196 
197 bool ProcTMVA::load()
198 {
199  bool ok = true;
200  for(std::vector<Method>::const_iterator iter = methods.begin();
201  iter != methods.end(); ++iter) {
202  std::ifstream in(getWeightsFile(*iter, "xml").c_str());
203  if (!in.good()) {
204  ok = false;
205  break;
206  }
207  }
208 
209  if (!ok)
210  return false;
211 
212  iteration = ITER_DONE;
213  trained = true;
214  return true;
215 }
216 
/// Number of bytes between the current read position of @p in and the end
/// of the stream. The read position is restored before returning.
static std::size_t getStreamSize(std::ifstream &in)
{
	const std::ifstream::pos_type start = in.tellg();
	in.seekg(0, std::ios::end);
	const std::ifstream::pos_type stop = in.tellg();
	in.seekg(start);

	return static_cast<std::size_t>(stop - start);
}
226 
227 Calibration::VarProcessor *ProcTMVA::getCalibration() const
228 {
230 
231  std::ifstream in(getWeightsFile(methods[0], "xml").c_str(),
232  std::ios::binary | std::ios::in);
233  if (!in.good())
234  throw cms::Exception("ProcTMVA")
235  << "Weights file " << getWeightsFile(methods[0], "xml")
236  << " cannot be opened for reading." << std::endl;
237 
238  std::size_t size = getStreamSize(in) + methods[0].name.size();
239  for(std::vector<std::string>::const_iterator iter = names.begin();
240  iter != names.end(); ++iter)
241  size += iter->size() + 1;
242  size += (size / 32) + 128;
243 
244  std::shared_ptr<char> buffer( new char[size] );
245  ext::omemstream os(buffer.get(), size);
246  /* call dtor of ozs at end */ {
247  ext::ozstream ozs(&os);
248  ozs << methods[0].name << "\n";
249  ozs << names.size() << "\n";
250  for(std::vector<std::string>::const_iterator iter =
251  names.begin();
252  iter != names.end(); ++iter)
253  ozs << *iter << "\n";
254  ozs << in.rdbuf();
255  ozs.flush();
256  }
257  size = os.end() - os.begin();
258  calib->store.resize(size);
259  std::memcpy(&calib->store.front(), os.begin(), size);
260 
261  in.close();
262 
263  calib->method = "ProcTMVA";
264 
265  return calib;
266 }
267 
268 void ProcTMVA::trainBegin()
269 {
270  if (iteration == ITER_EXPORT) {
271  ROOTContextSentinel ctx;
272 
273  file = std::auto_ptr<TFile>(TFile::Open(
274  trainer->trainFileName(this, "root",
275  "input").c_str(),
276  "RECREATE"));
277  if (!file.get())
278  throw cms::Exception("ProcTMVA")
279  << "Could not open ROOT file for writing."
280  << std::endl;
281 
282  file->cd();
283  treeSig = new TTree((getTreeName() + "_sig").c_str(),
284  "MVATrainer signal");
285  treeBkg = new TTree((getTreeName() + "_bkg").c_str(),
286  "MVATrainer background");
287 
288  treeSig->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
289  treeBkg->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
290 
291  vars.resize(names.size());
292 
293  std::vector<Double_t>::iterator pos = vars.begin();
294  for(std::vector<std::string>::const_iterator iter =
295  names.begin(); iter != names.end(); iter++, pos++) {
296  treeSig->Branch(iter->c_str(), &*pos,
297  (*iter + "/D").c_str());
298  treeBkg->Branch(iter->c_str(), &*pos,
299  (*iter + "/D").c_str());
300  }
301 
302  nSignal = nBackground = 0;
303  }
304 }
305 
306 void ProcTMVA::trainData(const std::vector<double> *values,
307  bool target, double weight)
308 {
309  if (iteration != ITER_EXPORT)
310  return;
311 
312  this->weight = weight;
313  for(unsigned int i = 0; i < vars.size(); i++, values++)
314  vars[i] = values->front();
315 
316  if (target) {
317  treeSig->Fill();
318  nSignal++;
319  } else {
320  treeBkg->Fill();
321  nBackground++;
322  }
323 }
324 
325 void ProcTMVA::runTMVATrainer()
326 {
327  needCleanup = true;
328 
329  if (nSignal < 1 || nBackground < 1)
330  throw cms::Exception("ProcTMVA")
331  << "Not going to run TMVA: "
332  "No signal (" << nSignal << ") or background ("
333  << nBackground << ") events!" << std::endl;
334 
335  std::auto_ptr<TFile> file(TFile::Open(
336  trainer->trainFileName(this, "root", "output").c_str(),
337  "RECREATE"));
338  if (!file.get())
339  throw cms::Exception("ProcTMVA")
340  << "Could not open TMVA ROOT file for writing."
341  << std::endl;
342 
343  std::auto_ptr<TMVA::Factory> factory(
344  new TMVA::Factory(getTreeName().c_str(), file.get(), ""));
345 
346  factory->SetInputTrees(treeSig, treeBkg);
347 
348  for(std::vector<std::string>::const_iterator iter = names.begin();
349  iter != names.end(); iter++)
350  factory->AddVariable(iter->c_str(), 'D');
351 
352  factory->SetWeightExpression("__WEIGHT__");
353 
354  if (doUserTreeSetup)
355  factory->PrepareTrainingAndTestTree(
356  setupCuts.c_str(), setupOptions);
357  else
358  factory->PrepareTrainingAndTestTree(
359  "", 0, 0, 0, 0,
360  "SplitMode=Block:!V");
361 
362  for(std::vector<Method>::const_iterator iter = methods.begin();
363  iter != methods.end(); ++iter)
364  factory->BookMethod(iter->type, iter->name, iter->description);
365 
366  factory->TrainAllMethods();
367  factory->TestAllMethods();
368  factory->EvaluateAllMethods();
369 
370  factory.release(); // ROOT seems to take care of destruction?!
371 
372  file->Close();
373 
374  printf("TMVA training factory completed\n");
375 }
376 
377 void ProcTMVA::trainEnd()
378 {
379  switch(iteration) {
380  case ITER_EXPORT:
381  /* ROOT context-safe */ {
382  ROOTContextSentinel ctx;
383  file->cd();
384  treeSig->Write();
385  treeBkg->Write();
386 
387  file->Close();
388  file.reset();
389  file = std::auto_ptr<TFile>(TFile::Open(
390  trainer->trainFileName(this, "root",
391  "input").c_str()));
392  if (!file.get())
393  throw cms::Exception("ProcTMVA")
394  << "Could not open ROOT file for "
395  "reading." << std::endl;
396  treeSig = dynamic_cast<TTree*>(
397  file->Get((getTreeName() + "_sig").c_str()));
398  treeBkg = dynamic_cast<TTree*>(
399  file->Get((getTreeName() + "_bkg").c_str()));
400 
401  runTMVATrainer();
402 
403  file->Close();
404  treeSig = 0;
405  treeBkg = 0;
406  file.reset();
407  }
408  vars.clear();
409 
410  iteration = ITER_DONE;
411  trained = true;
412  break;
413  default:
414  /* shut up */;
415  }
416 }
417 
418 void ProcTMVA::cleanup()
419 {
420  if (!needCleanup)
421  return;
422 
423  std::remove(trainer->trainFileName(this, "root", "input").c_str());
424  std::remove(trainer->trainFileName(this, "root", "output").c_str());
425  for(std::vector<Method>::const_iterator iter = methods.begin();
426  iter != methods.end(); ++iter) {
427  std::remove(getWeightsFile(*iter, "xml").c_str());
428  std::remove(getWeightsFile(*iter, "root").c_str());
429  }
430  rmdir("weights");
431 }
432 
433 } // anonymous namespace
434 
// Plugin hook for ProcTMVA. The macro is defined elsewhere and its
// expansion is not visible here; presumably it registers ProcTMVA with the
// MVATrainer plugin factory -- confirm.
435 MVA_TRAINER_DEFINE_PLUGIN(ProcTMVA);
type
Definition: HCALResponse.h:21
int i
Definition: DBlmapReader.cc:9
static const HistoName names[]
static void cleanup(const Factory::MakerMap::value_type &v)
Definition: Factory.cc:12
template to generate a registry singleton for a type.
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:7
tuple node
Definition: Node.py:50
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:31
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
tuple iteration
Definition: align_cfg.py:5
def load
Definition: svgfig.py:546
#define end
Definition: vmac.h:37
tuple description
Definition: idDealer.py:66
#define begin
Definition: vmac.h:30
dbl *** dir
Definition: mlp_gen.cc:35
volatile std::atomic< bool > shutdown_flag false
int weight
Definition: histoStyle.py:50
static Interceptor::Registry registry("Interceptor")
std::vector< unsigned char > store
Definition: MVAComputer.h:228
tuple size
Write out results.
#define MVA_TRAINER_DEFINE_PLUGIN(T)