CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
ProcTMVA.cc
Go to the documentation of this file.
1 #include <unistd.h>
2 #include <algorithm>
3 #include <iostream>
4 #include <sstream>
5 #include <fstream>
6 #include <cstddef>
7 #include <cstring>
8 #include <cstdio>
9 #include <vector>
10 #include <memory>
11 
12 #include <xercesc/dom/DOM.hpp>
13 
14 #include <TDirectory.h>
15 #include <TTree.h>
16 #include <TFile.h>
17 #include <TCut.h>
18 
19 #include <TMVA/Types.h>
20 #include <TMVA/Factory.h>
21 
23 
27 
33 
34 XERCES_CPP_NAMESPACE_USE
35 
36 using namespace PhysicsTools;
37 
38 namespace { // anonymous
39 
40 class ROOTContextSentinel {
41  public:
42  ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
43  ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
44 
45  private:
46  TDirectory *dir;
47  TFile *file;
48 };
49 
50 class ProcTMVA : public TrainProcessor {
51  public:
53 
54  ProcTMVA(const char *name, const AtomicId *id,
55  MVATrainer *trainer);
56  virtual ~ProcTMVA();
57 
58  virtual void configure(DOMElement *elem);
59  virtual Calibration::VarProcessor *getCalibration() const;
60 
61  virtual void trainBegin();
62  virtual void trainData(const std::vector<double> *values,
63  bool target, double weight);
64  virtual void trainEnd();
65 
66  virtual bool load();
67  virtual void cleanup();
68 
69  private:
70  void runTMVATrainer();
71 
72  struct Method {
73  TMVA::Types::EMVA type;
74  std::string name;
75  std::string description;
76  };
77 
78  std::string getTreeName() const
79  { return trainer->getName() + '_' + (const char*)getName(); }
80 
81  std::string getWeightsFile(const Method &meth, const char *ext) const
82  {
83  return "weights/" + getTreeName() + '_' +
84  meth.name + ".weights." + ext;
85  }
86 
87  enum Iteration {
88  ITER_EXPORT,
89  ITER_DONE
90  } iteration;
91 
92  std::vector<Method> methods;
93  std::vector<std::string> names;
94  std::auto_ptr<TFile> file;
95  TTree *treeSig, *treeBkg;
96  Double_t weight;
97  std::vector<Double_t> vars;
98  bool needCleanup;
99  unsigned long nSignal;
100  unsigned long nBackground;
101  bool doUserTreeSetup;
102  std::string setupCuts; // cut applied by TMVA to signal and background trees
103  std::string setupOptions; // training/test tree TMVA setup options
104 };
105 
106 static ProcTMVA::Registry registry("ProcTMVA");
107 
108 ProcTMVA::ProcTMVA(const char *name, const AtomicId *id,
109  MVATrainer *trainer) :
110  TrainProcessor(name, id, trainer),
111  iteration(ITER_EXPORT), treeSig(0), treeBkg(0), needCleanup(false),
112  doUserTreeSetup(false), setupOptions("SplitMode = Block:!V")
113 {
114 }
115 
116 ProcTMVA::~ProcTMVA()
117 {
118 }
119 
120 void ProcTMVA::configure(DOMElement *elem)
121 {
122  std::vector<SourceVariable*> inputs = getInputs().get();
123 
124  for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
125  iter != inputs.end(); iter++) {
126  std::string name = (const char*)(*iter)->getName();
127 
128  if (std::find(names.begin(), names.end(), name)
129  != names.end()) {
130  for(unsigned i = 1;; i++) {
131  std::ostringstream ss;
132  ss << name << "_" << i;
133  if (std::find(names.begin(), names.end(),
134  ss.str()) == names.end()) {
135  name == ss.str();
136  break;
137  }
138  }
139  }
140 
141  names.push_back(name);
142  }
143 
144  for(DOMNode *node = elem->getFirstChild();
145  node; node = node->getNextSibling()) {
146  if (node->getNodeType() != DOMNode::ELEMENT_NODE)
147  continue;
148 
149  bool isMethod = !std::strcmp(XMLSimpleStr(node->getNodeName()), "method");
150  bool isSetup = !std::strcmp(XMLSimpleStr(node->getNodeName()), "setup");
151 
152  if (!isMethod && !isSetup)
153  throw cms::Exception("ProcTMVA")
154  << "Expected method or setup tag in config section."
155  << std::endl;
156 
157  elem = static_cast<DOMElement*>(node);
158 
159  if (isMethod) {
160  Method method;
161  method.type = TMVA::Types::Instance().GetMethodType(
162  XMLDocument::readAttribute<std::string>(
163  elem, "type").c_str());
164 
165  method.name =
166  XMLDocument::readAttribute<std::string>(
167  elem, "name");
168 
169  method.description =
170  (const char*)XMLSimpleStr(node->getTextContent());
171 
172  methods.push_back(method);
173  } else if (isSetup) {
174  if (doUserTreeSetup)
175  throw cms::Exception("ProcTMVA")
176  << "Multiple appeareances of setup "
177  "tag in config section."
178  << std::endl;
179 
180  doUserTreeSetup = true;
181 
182  setupCuts =
183  XMLDocument::readAttribute<std::string>(
184  elem, "cuts");
185  setupOptions =
186  XMLDocument::readAttribute<std::string>(
187  elem, "options");
188  }
189  }
190 
191  if (!methods.size())
192  throw cms::Exception("ProcTMVA")
193  << "Expected TMVA method in config section."
194  << std::endl;
195 }
196 
197 bool ProcTMVA::load()
198 {
199  bool ok = true;
200  for(std::vector<Method>::const_iterator iter = methods.begin();
201  iter != methods.end(); ++iter) {
202  std::ifstream in(getWeightsFile(*iter, "xml").c_str());
203  if (!in.good()) {
204  ok = false;
205  break;
206  }
207  }
208 
209  if (!ok)
210  return false;
211 
212  iteration = ITER_DONE;
213  trained = true;
214  return true;
215 }
216 
// Returns the number of bytes between the current read position of `in` and
// the end of the stream. The read position is restored before returning.
static std::size_t getStreamSize(std::ifstream &in)
{
	const std::ifstream::pos_type here = in.tellg();
	in.seekg(0, std::ios::end);
	const std::ifstream::pos_type eof = in.tellg();
	in.seekg(here);

	return static_cast<std::size_t>(eof - here);
}
226 
227 Calibration::VarProcessor *ProcTMVA::getCalibration() const
228 {
230 
231  std::ifstream in(getWeightsFile(methods[0], "xml").c_str(),
232  std::ios::binary | std::ios::in);
233  if (!in.good())
234  throw cms::Exception("ProcTMVA")
235  << "Weights file " << getWeightsFile(methods[0], "xml")
236  << " cannot be opened for reading." << std::endl;
237 
238  std::size_t size = getStreamSize(in) + methods[0].name.size();
239  for(std::vector<std::string>::const_iterator iter = names.begin();
240  iter != names.end(); ++iter)
241  size += iter->size() + 1;
242  size += (size / 32) + 128;
243 
244  char *buffer = 0;
245  try {
246  buffer = new char[size];
247  ext::omemstream os(buffer, size);
248  /* call dtor of ozs at end */ {
249  ext::ozstream ozs(&os);
250  ozs << methods[0].name << "\n";
251  ozs << names.size() << "\n";
252  for(std::vector<std::string>::const_iterator iter =
253  names.begin();
254  iter != names.end(); ++iter)
255  ozs << *iter << "\n";
256  ozs << in.rdbuf();
257  ozs.flush();
258  }
259  size = os.end() - os.begin();
260  calib->store.resize(size);
261  std::memcpy(&calib->store.front(), os.begin(), size);
262  } catch(...) {
263  delete[] buffer;
264  throw;
265  }
266  delete[] buffer;
267  in.close();
268 
269  calib->method = "ProcTMVA";
270 
271  return calib;
272 }
273 
274 void ProcTMVA::trainBegin()
275 {
276  if (iteration == ITER_EXPORT) {
277  ROOTContextSentinel ctx;
278 
279  file = std::auto_ptr<TFile>(TFile::Open(
280  trainer->trainFileName(this, "root",
281  "input").c_str(),
282  "RECREATE"));
283  if (!file.get())
284  throw cms::Exception("ProcTMVA")
285  << "Could not open ROOT file for writing."
286  << std::endl;
287 
288  file->cd();
289  treeSig = new TTree((getTreeName() + "_sig").c_str(),
290  "MVATrainer signal");
291  treeBkg = new TTree((getTreeName() + "_bkg").c_str(),
292  "MVATrainer background");
293 
294  treeSig->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
295  treeBkg->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
296 
297  vars.resize(names.size());
298 
299  std::vector<Double_t>::iterator pos = vars.begin();
300  for(std::vector<std::string>::const_iterator iter =
301  names.begin(); iter != names.end(); iter++, pos++) {
302  treeSig->Branch(iter->c_str(), &*pos,
303  (*iter + "/D").c_str());
304  treeBkg->Branch(iter->c_str(), &*pos,
305  (*iter + "/D").c_str());
306  }
307 
308  nSignal = nBackground = 0;
309  }
310 }
311 
312 void ProcTMVA::trainData(const std::vector<double> *values,
313  bool target, double weight)
314 {
315  if (iteration != ITER_EXPORT)
316  return;
317 
318  this->weight = weight;
319  for(unsigned int i = 0; i < vars.size(); i++, values++)
320  vars[i] = values->front();
321 
322  if (target) {
323  treeSig->Fill();
324  nSignal++;
325  } else {
326  treeBkg->Fill();
327  nBackground++;
328  }
329 }
330 
331 void ProcTMVA::runTMVATrainer()
332 {
333  needCleanup = true;
334 
335  if (nSignal < 1 || nBackground < 1)
336  throw cms::Exception("ProcTMVA")
337  << "Not going to run TMVA: "
338  "No signal (" << nSignal << ") or background ("
339  << nBackground << ") events!" << std::endl;
340 
341  std::auto_ptr<TFile> file(TFile::Open(
342  trainer->trainFileName(this, "root", "output").c_str(),
343  "RECREATE"));
344  if (!file.get())
345  throw cms::Exception("ProcTMVA")
346  << "Could not open TMVA ROOT file for writing."
347  << std::endl;
348 
349  std::auto_ptr<TMVA::Factory> factory(
350  new TMVA::Factory(getTreeName().c_str(), file.get(), ""));
351 
352  factory->SetInputTrees(treeSig, treeBkg);
353 
354  for(std::vector<std::string>::const_iterator iter = names.begin();
355  iter != names.end(); iter++)
356  factory->AddVariable(iter->c_str(), 'D');
357 
358  factory->SetWeightExpression("__WEIGHT__");
359 
360  if (doUserTreeSetup)
361  factory->PrepareTrainingAndTestTree(
362  setupCuts.c_str(), setupOptions);
363  else
364  factory->PrepareTrainingAndTestTree(
365  "", 0, 0, 0, 0,
366  "SplitMode=Block:!V");
367 
368  for(std::vector<Method>::const_iterator iter = methods.begin();
369  iter != methods.end(); ++iter)
370  factory->BookMethod(iter->type, iter->name, iter->description);
371 
372  factory->TrainAllMethods();
373  factory->TestAllMethods();
374  factory->EvaluateAllMethods();
375 
376  factory.release(); // ROOT seems to take care of destruction?!
377 
378  file->Close();
379 
380  printf("TMVA training factory completed\n");
381 }
382 
383 void ProcTMVA::trainEnd()
384 {
385  switch(iteration) {
386  case ITER_EXPORT:
387  /* ROOT context-safe */ {
388  ROOTContextSentinel ctx;
389  file->cd();
390  treeSig->Write();
391  treeBkg->Write();
392 
393  file->Close();
394  file.reset();
395  file = std::auto_ptr<TFile>(TFile::Open(
396  trainer->trainFileName(this, "root",
397  "input").c_str()));
398  if (!file.get())
399  throw cms::Exception("ProcTMVA")
400  << "Could not open ROOT file for "
401  "reading." << std::endl;
402  treeSig = dynamic_cast<TTree*>(
403  file->Get((getTreeName() + "_sig").c_str()));
404  treeBkg = dynamic_cast<TTree*>(
405  file->Get((getTreeName() + "_bkg").c_str()));
406 
407  runTMVATrainer();
408 
409  file->Close();
410  treeSig = 0;
411  treeBkg = 0;
412  file.reset();
413  }
414  vars.clear();
415 
416  iteration = ITER_DONE;
417  trained = true;
418  break;
419  default:
420  /* shut up */;
421  }
422 }
423 
424 void ProcTMVA::cleanup()
425 {
426  if (!needCleanup)
427  return;
428 
429  std::remove(trainer->trainFileName(this, "root", "input").c_str());
430  std::remove(trainer->trainFileName(this, "root", "output").c_str());
431  for(std::vector<Method>::const_iterator iter = methods.begin();
432  iter != methods.end(); ++iter) {
433  std::remove(getWeightsFile(*iter, "xml").c_str());
434  std::remove(getWeightsFile(*iter, "root").c_str());
435  }
436  rmdir("weights");
437 }
438 
439 } // anonymous namespace
440 
// Plugin registration for ProcTMVA; the macro is defined outside this file
// (presumably in the MVATrainer plugin headers -- confirm).
MVA_TRAINER_DEFINE_PLUGIN(ProcTMVA);
type
Definition: HCALResponse.h:22
int i
Definition: DBlmapReader.cc:9
MVA_TRAINER_DEFINE_PLUGIN(ProcTMVA)
static void cleanup(const Factory::MakerMap::value_type &v)
Definition: Factory.cc:12
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
Definition: FindCaloHit.cc:7
detail::ThreadSafeRegistry< ParameterSetID, ParameterSet, ProcessParameterSetIDCache > Registry
Definition: Registry.h:37
tuple node
Definition: Node.py:50
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:32
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
tuple iteration
Definition: align_cfg.py:5
def load
Definition: svgfig.py:546
#define end
Definition: vmac.h:38
std::string getName(Reflex::Type &cc)
Definition: ClassFiller.cc:18
tuple description
Definition: idDealer.py:66
#define begin
Definition: vmac.h:31
template to generate a registry singleton for a type.
dbl *** dir
Definition: mlp_gen.cc:35
static Interceptor::Registry registry("Interceptor")
static const HistoName names[]
std::vector< unsigned char > store
Definition: MVAComputer.h:148
tuple size
Write out results.