CMS 3D CMS Logo

TreeSaver.cc
Go to the documentation of this file.
1 #include <unistd.h>
2 #include <functional>
3 #include <algorithm>
4 #include <iostream>
5 #include <sstream>
6 #include <fstream>
7 #include <cstddef>
8 #include <cstring>
9 #include <cstdio>
10 #include <vector>
11 #include <memory>
12 
13 #include <xercesc/dom/DOM.hpp>
14 
15 #include <TDirectory.h>
16 #include <TTree.h>
17 #include <TFile.h>
18 #include <TCut.h>
19 
21 
25 
31 
33 
34 using namespace PhysicsTools;
35 
36 namespace { // anonymous
37 
38 class ROOTContextSentinel {
39  public:
40  ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
41  ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
42 
43  private:
44  TDirectory *dir;
45  TFile *file;
46 };
47 
// TrainProcessor implementation that writes every training sample it is
// given into a ROOT TTree stored in a ROOT file (see trainBegin(),
// trainData() and trainEnd() below).
//
// NOTE(review): this listing skips original lines 50 and 76-77. Code later
// in the file uses TreeSaver::Registry, Var::name and Var::flags, which are
// presumably declared on those missing lines -- confirm against the real
// source before editing this class.
48 class TreeSaver : public TrainProcessor {
49  public:
51 
52  TreeSaver(const char *name, const AtomicId *id,
53  MVATrainer *trainer);
54  ~TreeSaver() override;
55 
// Creates one Var entry per input variable, with a unique name each.
56  void configure(DOMElement *elem) override;
// Stores the per-variable flags; flags.size() must equal vars.size().
57  void passFlags(const std::vector<Variable::Flags> &flags) override;
58 
// Opens the output file and creates the tree (only while ITER_EXPORT).
59  void trainBegin() override;
// Copies one sample into the branch buffers and fills the tree.
// `values` is indexed as an array with one std::vector<double> per variable.
60  void trainData(const std::vector<double> *values,
61  bool target, double weight) override;
// Writes the tree, closes the file and switches to ITER_DONE.
62  void trainEnd() override;
63 
64  private:
// Creates the tree branches; called once both passFlags() and trainBegin()
// have run, from whichever of the two comes second.
65  void init();
66 
// Tree name: "<trainer name>_<processor name>".
67  std::string getTreeName() const
68  { return trainer->getName() + '_' + (const char*)getName(); }
69 
// ITER_EXPORT is the initial state (set in the constructor); trainEnd()
// moves to ITER_DONE, after which trainBegin()/trainData() do nothing.
70  enum Iteration {
71  ITER_EXPORT,
72  ITER_DONE
73  } iteration;
74 
// Per-variable branch buffer.
75  struct Var {
// Buffer for a single-valued variable (bound to a "<name>/D" branch).
78  double value;
// Buffer for a variable flagged Variable::FLAG_MULTIPLE.
79  std::vector<double> values;
// Set to &values in init(); the address of this pointer is what is passed
// to TTree::Branch for the "std::vector<double>" branch.
80  std::vector<double> *ptr;
81 
// True when this variable's name equals `other` (argument taken by value).
82  bool hasName(std::string other) const
83  { return name == other; }
84  };
85 
// Output file: opened in trainBegin(), closed and released in trainEnd().
86  std::unique_ptr<TFile> file;
// Created with `new` in trainBegin() and never deleted in this file.
// NOTE(review): presumably owned by the ROOT file's directory -- confirm.
87  TTree *tree;
// Buffers bound to the "__WEIGHT__" and "__TARGET__" branches in init().
88  Double_t weight;
89  Bool_t target;
// One entry per input variable, in the order configure() saw the inputs.
90  std::vector<Var> vars;
// flagsPassed: passFlags() has been called; begun: trainBegin() has run.
91  bool flagsPassed, begun;
92 };
93 
// File-local Registry object constructed with the string "TreeSaver".
// NOTE(review): presumably this is what makes the processor available to the
// trainer framework under that name -- confirm against the Registry type,
// whose declaration is not visible in this listing.
94 TreeSaver::Registry registry("TreeSaver");
95 
96 TreeSaver::TreeSaver(const char *name, const AtomicId *id,
97  MVATrainer *trainer) :
98  TrainProcessor(name, id, trainer),
99  iteration(ITER_EXPORT), tree(nullptr), flagsPassed(false), begun(false)
100 {
101 }
102 
103 TreeSaver::~TreeSaver()
104 {
105 }
106 
107 void TreeSaver::configure(DOMElement *elem)
108 {
109  std::vector<SourceVariable*> inputs = getInputs().get();
110 
111  for( auto const& input : inputs ) {
112  std::string name = static_cast<const char*>(input->getName());
113 
114  if (std::find_if(vars.begin(), vars.end(),
115  [&name](auto const& c){return c.hasName(name);})
116  != vars.end()) {
117  for(unsigned i = 1;; i++) {
118  std::ostringstream ss;
119  ss << name << "_" << i;
120  if (std::find_if(vars.begin(), vars.end(),
121  [&ss](auto c){return c.hasName(ss.str());})
122  == vars.end()) {
123  name = ss.str();
124  break;
125  }
126  }
127  }
128 
129  Var var;
130  var.name = name;
131  var.flags = Variable::FLAG_NONE;
132  var.ptr = nullptr;
133  vars.push_back(var);
134  }
135 }
136 
137 void TreeSaver::init()
138 {
139  tree->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
140  tree->Branch("__TARGET__", &target, "__TARGET__/O");
141 
142  vars.resize(vars.size());
143 
144  std::vector<Var>::iterator pos = vars.begin();
145  for(std::vector<Var>::iterator iter = vars.begin();
146  iter != vars.end(); iter++, pos++) {
147  if (iter->flags & Variable::FLAG_MULTIPLE) {
148  iter->ptr = &iter->values;
149  tree->Branch(iter->name.c_str(),
150  "std::vector<double>",
151  &pos->ptr);
152  } else
153  tree->Branch(iter->name.c_str(), &pos->value,
154  (iter->name + "/D").c_str());
155  }
156 }
157 
158 void TreeSaver::passFlags(const std::vector<Variable::Flags> &flags)
159 {
160  assert(flags.size() == vars.size());
161  unsigned int idx = 0;
162  for(std::vector<Variable::Flags>::const_iterator iter = flags.begin();
163  iter != flags.end(); ++iter, idx++)
164  vars[idx].flags = *iter;
165 
166  if (begun && !flagsPassed)
167  init();
168  flagsPassed = true;
169 }
170 
171 void TreeSaver::trainBegin()
172 {
173  if (iteration == ITER_EXPORT) {
174  ROOTContextSentinel ctx;
175 
176  file = std::unique_ptr<TFile>(TFile::Open(
177  trainer->trainFileName(this, "root").c_str(),
178  "RECREATE"));
179  if (!file.get())
180  throw cms::Exception("TreeSaver")
181  << "Could not open ROOT file for writing."
182  << std::endl;
183 
184  file->cd();
185  tree = new TTree(getTreeName().c_str(),
186  "MVATrainer signal and background");
187 
188  if (!begun && flagsPassed)
189  init();
190  begun = true;
191  }
192 }
193 
194 void TreeSaver::trainData(const std::vector<double> *values,
195  bool target, double weight)
196 {
197  if (iteration != ITER_EXPORT)
198  return;
199 
200  this->weight = weight;
201  this->target = target;
202  for(unsigned int i = 0; i < vars.size(); i++, values++) {
203  Var &var = vars[i];
204  if (var.flags & Variable::FLAG_MULTIPLE)
205  var.values = *values;
206  else if (values->empty())
207  var.value = -999.0;
208  else
209  var.value = values->front();
210  }
211 
212  tree->Fill();
213 }
214 
215 void TreeSaver::trainEnd()
216 {
217  switch(iteration) {
218  case ITER_EXPORT:
219  /* ROOT context-safe */ {
220  ROOTContextSentinel ctx;
221  file->cd();
222  tree->Write();
223  file->Close();
224  file.reset();
225  }
226  vars.clear();
227 
228  iteration = ITER_DONE;
229  trained = true;
230  break;
231  default:
232  /* shut up */;
233  }
234 }
235 
236 } // anonymous namespace
int init
Definition: HydjetWrapper.h:67
std::vector< Variable::Flags > flags
Definition: MVATrainer.cc:135
Definition: weight.py:1
Template to generate a registry singleton for a type.
#define nullptr
Cheap generic unique keyword identifier class.
Definition: AtomicId.h:31
static std::string const input
Definition: EdmProvDump.cc:45
def Var(expr, valtype, compression=None, doc=None, mcOnly=False, precision=-1)
Definition: common_cff.py:20
def elem(elemtype, innerHTML='', html_class='', kwargs)
Definition: HTMLExport.py:18
Definition: tree.py:1
dbl *** dir
Definition: mlp_gen.cc:35
static Interceptor::Registry registry("Interceptor")