CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
TreeTrainer.cc
Go to the documentation of this file.
1 #include <functional>
2 #include <algorithm>
3 #include <string>
4 #include <vector>
5 
6 #include <TString.h>
7 #include <TTree.h>
8 
10 
13 
16 
17 namespace PhysicsTools {
18 
20 {
21 }
22 
24 {
25  addTree(tree, -1, weight);
26 }
27 
28 TreeTrainer::TreeTrainer(TTree *signal, TTree *background, double weight)
29 {
30  addTree(signal, true, weight);
31  addTree(background, false, weight);
32 }
33 
35 {
36  reset();
37 }
38 
40  double crossValidation,
41  bool useXSLT)
42 {
43  MVATrainer trainer(trainFile, useXSLT);
44  trainer.setMonitoring(true);
45  trainer.setCrossValidation(crossValidation);
46  train(&trainer);
47  return trainer.getCalibration();
48 }
49 
51 {
52  readers.clear();
53  std::for_each(weights.begin(), weights.end(),
54  std::ptr_fun(&::operator delete));
55  weights.clear();
56 }
57 
58 void TreeTrainer::addTree(TTree *tree, int target, double weight)
59 {
60  static const bool targets[2] = { true, false };
61 
62  TreeReader reader(tree, false, weight > 0.0);
63 
64  if (target >= 0) {
65  if (tree->GetBranch("__TARGET__"))
66  throw cms::Exception("TreeTrainer")
67  << "__TARGET__ branch already present in file."
68  << std::endl;
69 
70  reader.addSingle(MVATrainer::kTargetId, &targets[!target]);
71  }
72 
73  if (weight > 0.0) {
74  double *ptr = new double(weight);
75  weights.push_back(ptr);
76  reader.addSingle(MVATrainer::kWeightId, ptr);
77  }
78 
79  addReader(reader);
80 }
81 
83 {
84  readers.push_back(reader);
85 }
86 
88 {
90  if (!calib)
91  return true;
92 
93  MVAComputer computer(calib, true);
94 
95  std::for_each(readers.begin(), readers.end(),
96  std::bind2nd(std::mem_fun_ref(&TreeReader::loop),
97  &computer));
98 
99  return false;
100 }
101 
103 {
104  while(!iteration(trainer));
105 }
106 
107 } // namespace PhysicsTools
Calibration::MVAComputer * train(const std::string &trainDescription, double crossValidation=0.0, bool useXSLT=false)
Definition: TreeTrainer.cc:39
bool iteration(MVATrainer *trainer)
Definition: TreeTrainer.cc:87
void setMonitoring(bool monitoring)
Definition: MVATrainer.h:35
static const AtomicId kTargetId
Definition: MVATrainer.h:59
void addSingle(AtomicId name, const T *value, bool opt=false)
static const AtomicId kWeightId
Definition: MVATrainer.h:60
MVATrainerComputer * calib
Definition: MVATrainer.cc:64
std::vector< TreeReader > readers
Definition: TreeTrainer.h:40
Main interface class to the generic discriminator computer framework.
Definition: MVAComputer.h:39
void addTree(TTree *tree, int target=-1, double weight=-1.0)
Definition: TreeTrainer.cc:58
std::vector< double * > weights
Definition: TreeTrainer.h:42
void addReader(const TreeReader &reader)
Definition: TreeTrainer.cc:82
uint64_t loop(const MVAComputer *mva)
Definition: TreeReader.cc:298
Calibration::MVAComputer * getCalibration() const
Definition: MVATrainer.cc:1146
void setCrossValidation(double split)
Definition: MVATrainer.h:37
int weight
Definition: histoStyle.py:50
Calibration::MVAComputer * getTrainCalibration() const
Definition: MVATrainer.cc:1241