CMS 3D CMS Logo

TreeSaver.cc
Go to the documentation of this file.
1 #include <unistd.h>
2 #include <functional>
3 #include <algorithm>
4 #include <iostream>
5 #include <sstream>
6 #include <fstream>
7 #include <cstddef>
8 #include <cstring>
9 #include <cstdio>
10 #include <vector>
11 #include <memory>
12 
13 #include <xercesc/dom/DOM.hpp>
14 
15 #include <TDirectory.h>
16 #include <TTree.h>
17 #include <TFile.h>
18 #include <TCut.h>
19 
21 
25 
31 
33 
34 using namespace PhysicsTools;
35 
36 namespace { // anonymous
37 
38 class ROOTContextSentinel {
39  public:
40  ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
41  ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
42 
43  private:
44  TDirectory *dir;
45  TFile *file;
46 };
47 
48 class TreeSaver : public TrainProcessor {
49  public:
51 
52  TreeSaver(const char *name, const AtomicId *id,
53  MVATrainer *trainer);
54  virtual ~TreeSaver();
55 
56  virtual void configure(DOMElement *elem) override;
57  virtual void passFlags(const std::vector<Variable::Flags> &flags) override;
58 
59  virtual void trainBegin() override;
60  virtual void trainData(const std::vector<double> *values,
61  bool target, double weight) override;
62  virtual void trainEnd() override;
63 
64  private:
65  void init();
66 
67  std::string getTreeName() const
68  { return trainer->getName() + '_' + (const char*)getName(); }
69 
70  enum Iteration {
71  ITER_EXPORT,
72  ITER_DONE
73  } iteration;
74 
75  struct Var {
78  double value;
79  std::vector<double> values;
80  std::vector<double> *ptr;
81 
82  bool hasName(std::string other) const
83  { return name == other; }
84  };
85 
86  std::unique_ptr<TFile> file;
87  TTree *tree;
88  Double_t weight;
89  Bool_t target;
90  std::vector<Var> vars;
91  bool flagsPassed, begun;
92 };
93 
94 TreeSaver::Registry registry("TreeSaver");
95 
96 TreeSaver::TreeSaver(const char *name, const AtomicId *id,
97  MVATrainer *trainer) :
98  TrainProcessor(name, id, trainer),
99  iteration(ITER_EXPORT), tree(0), flagsPassed(false), begun(false)
100 {
101 }
102 
103 TreeSaver::~TreeSaver()
104 {
105 }
106 
107 void TreeSaver::configure(DOMElement *elem)
108 {
109  std::vector<SourceVariable*> inputs = getInputs().get();
110 
111  for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
112  iter != inputs.end(); iter++) {
113  std::string name = (const char*)(*iter)->getName();
114 
115  if (std::find_if(vars.begin(), vars.end(),
116  std::bind2nd(std::mem_fun_ref(&Var::hasName),
117  name)) != vars.end()) {
118  for(unsigned i = 1;; i++) {
119  std::ostringstream ss;
120  ss << name << "_" << i;
121  if (std::find_if(vars.begin(), vars.end(),
122  std::bind2nd(
123  std::mem_fun_ref(
124  &Var::hasName),
125  ss.str())) ==
126  vars.end()) {
127  name = ss.str();
128  break;
129  }
130  }
131  }
132 
133  Var var;
134  var.name = name;
135  var.flags = Variable::FLAG_NONE;
136  var.ptr = 0;
137  vars.push_back(var);
138  }
139 }
140 
141 void TreeSaver::init()
142 {
143  tree->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
144  tree->Branch("__TARGET__", &target, "__TARGET__/O");
145 
146  vars.resize(vars.size());
147 
148  std::vector<Var>::iterator pos = vars.begin();
149  for(std::vector<Var>::iterator iter = vars.begin();
150  iter != vars.end(); iter++, pos++) {
151  if (iter->flags & Variable::FLAG_MULTIPLE) {
152  iter->ptr = &iter->values;
153  tree->Branch(iter->name.c_str(),
154  "std::vector<double>",
155  &pos->ptr);
156  } else
157  tree->Branch(iter->name.c_str(), &pos->value,
158  (iter->name + "/D").c_str());
159  }
160 }
161 
162 void TreeSaver::passFlags(const std::vector<Variable::Flags> &flags)
163 {
164  assert(flags.size() == vars.size());
165  unsigned int idx = 0;
166  for(std::vector<Variable::Flags>::const_iterator iter = flags.begin();
167  iter != flags.end(); ++iter, idx++)
168  vars[idx].flags = *iter;
169 
170  if (begun && !flagsPassed)
171  init();
172  flagsPassed = true;
173 }
174 
175 void TreeSaver::trainBegin()
176 {
177  if (iteration == ITER_EXPORT) {
178  ROOTContextSentinel ctx;
179 
180  file = std::unique_ptr<TFile>(TFile::Open(
181  trainer->trainFileName(this, "root").c_str(),
182  "RECREATE"));
183  if (!file.get())
184  throw cms::Exception("TreeSaver")
185  << "Could not open ROOT file for writing."
186  << std::endl;
187 
188  file->cd();
189  tree = new TTree(getTreeName().c_str(),
190  "MVATrainer signal and background");
191 
192  if (!begun && flagsPassed)
193  init();
194  begun = true;
195  }
196 }
197 
198 void TreeSaver::trainData(const std::vector<double> *values,
199  bool target, double weight)
200 {
201  if (iteration != ITER_EXPORT)
202  return;
203 
204  this->weight = weight;
205  this->target = target;
206  for(unsigned int i = 0; i < vars.size(); i++, values++) {
207  Var &var = vars[i];
208  if (var.flags & Variable::FLAG_MULTIPLE)
209  var.values = *values;
210  else if (values->empty())
211  var.value = -999.0;
212  else
213  var.value = values->front();
214  }
215 
216  tree->Fill();
217 }
218 
219 void TreeSaver::trainEnd()
220 {
221  switch(iteration) {
222  case ITER_EXPORT:
223  /* ROOT context-safe */ {
224  ROOTContextSentinel ctx;
225  file->cd();
226  tree->Write();
227  file->Close();
228  file.reset();
229  }
230  vars.clear();
231 
232  iteration = ITER_DONE;
233  trained = true;
234  break;
235  default:
236  /* shut up */;
237  }
238 }
239 
240 } // anonymous namespace
int init
Definition: HydjetWrapper.h:67
std::vector< Variable::Flags > flags
Definition: MVATrainer.cc:135
Definition: weight.py:1
template to generate a registry singleton for a type.
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:31
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:18
Definition: tree.py:1
dbl *** dir
Definition: mlp_gen.cc:35
static Interceptor::Registry registry("Interceptor")