13 #include <xercesc/dom/DOM.hpp> 15 #include <TDirectory.h> 38 class ROOTContextSentinel {
40 ROOTContextSentinel() :
dir(gDirectory),
file(gFile) {}
41 ~ROOTContextSentinel() { gDirectory =
dir; gFile =
file; }
54 ~TreeSaver()
override;
56 void configure(DOMElement *
elem)
override;
57 void passFlags(
const std::vector<Variable::Flags> &
flags)
override;
59 void trainBegin()
override;
60 void trainData(
const std::vector<double> *
values,
62 void trainEnd()
override;
68 {
return trainer->getName() +
'_' + (
const char*)getName(); }
79 std::vector<double>
values;
80 std::vector<double> *ptr;
83 {
return name ==
other; }
86 std::unique_ptr<TFile>
file;
90 std::vector<Var> vars;
91 bool flagsPassed, begun;
96 TreeSaver::TreeSaver(
const char *
name,
const AtomicId *
id,
103 TreeSaver::~TreeSaver()
107 void TreeSaver::configure(DOMElement *
elem)
109 std::vector<SourceVariable*>
inputs = getInputs().get();
111 for(
auto const&
input : inputs ) {
114 if (std::find_if(vars.begin(), vars.end(),
115 [&
name](
auto const&
c){
return c.hasName(name);})
117 for(
unsigned i = 1;;
i++) {
118 std::ostringstream ss;
119 ss << name <<
"_" <<
i;
120 if (std::find_if(vars.begin(), vars.end(),
121 [&ss](
auto c){
return c.hasName(ss.str());})
139 tree->Branch(
"__WEIGHT__", &
weight,
"__WEIGHT__/D");
140 tree->Branch(
"__TARGET__", &
target,
"__TARGET__/O");
142 vars.resize(vars.size());
144 std::vector<Var>::iterator
pos = vars.begin();
145 for(std::vector<Var>::iterator iter = vars.begin();
146 iter != vars.end(); iter++, pos++) {
148 iter->ptr = &iter->values;
149 tree->Branch(iter->name.c_str(),
150 "std::vector<double>",
153 tree->Branch(iter->name.c_str(), &pos->value,
154 (iter->name +
"/D").c_str());
158 void TreeSaver::passFlags(
const std::vector<Variable::Flags> &
flags)
160 assert(flags.size() == vars.size());
161 unsigned int idx = 0;
162 for(std::vector<Variable::Flags>::const_iterator iter = flags.begin();
163 iter != flags.end(); ++iter, idx++)
164 vars[idx].flags = *iter;
166 if (begun && !flagsPassed)
171 void TreeSaver::trainBegin()
174 ROOTContextSentinel ctx;
176 file = std::unique_ptr<TFile>(TFile::Open(
177 trainer->trainFileName(
this,
"root").c_str(),
181 <<
"Could not open ROOT file for writing." 185 tree =
new TTree(getTreeName().c_str(),
186 "MVATrainer signal and background");
188 if (!begun && flagsPassed)
194 void TreeSaver::trainData(
const std::vector<double> *
values,
202 for(
unsigned int i = 0; i < vars.size(); i++, values++) {
206 else if (values->empty())
209 var.value = values->front();
215 void TreeSaver::trainEnd()
220 ROOTContextSentinel ctx;
std::vector< Variable::Flags > flags
static std::string const input
def Var(expr, valtype, compression=None, doc=None, mcOnly=False, precision=-1)
def elem(elemtype, innerHTML='', html_class='', kwargs)