13 #include <xercesc/dom/DOM.hpp>
15 #include <TDirectory.h>
32 XERCES_CPP_NAMESPACE_USE
34 using namespace PhysicsTools;
38 class ROOTContextSentinel {
40 ROOTContextSentinel() :
dir(gDirectory),
file(gFile) {}
41 ~ROOTContextSentinel() { gDirectory =
dir; gFile =
file; }
56 virtual void configure(DOMElement *
elem);
57 virtual void passFlags(
const std::vector<Variable::Flags> &
flags);
59 virtual void trainBegin();
60 virtual void trainData(
const std::vector<double> *
values,
62 virtual void trainEnd();
67 std::string getTreeName()
const
68 {
return trainer->getName() +
'_' + (
const char*)
getName(); }
79 std::vector<double>
values;
80 std::vector<double> *ptr;
82 bool hasName(std::string other)
const
83 {
return name == other; }
86 std::auto_ptr<TFile>
file;
90 std::vector<Var> vars;
91 bool flagsPassed, begun;
96 TreeSaver::TreeSaver(
const char *
name,
const AtomicId *
id,
103 TreeSaver::~TreeSaver()
107 void TreeSaver::configure(DOMElement *
elem)
109 std::vector<SourceVariable*>
inputs = getInputs().get();
111 for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
112 iter != inputs.end(); iter++) {
113 std::string
name = (
const char*)(*iter)->getName();
115 if (std::find_if(vars.begin(), vars.end(),
116 std::bind2nd(std::mem_fun_ref(&Var::hasName),
117 name)) != vars.end()) {
118 for(
unsigned i = 1;;
i++) {
119 std::ostringstream ss;
120 ss << name <<
"_" <<
i;
121 if (std::find_if(vars.begin(), vars.end(),
143 tree->Branch(
"__WEIGHT__", &
weight,
"__WEIGHT__/D");
144 tree->Branch(
"__TARGET__", &
target,
"__TARGET__/O");
146 vars.resize(vars.size());
148 std::vector<Var>::iterator
pos = vars.begin();
149 for(std::vector<Var>::iterator iter = vars.begin();
150 iter != vars.end(); iter++, pos++) {
152 iter->ptr = &iter->values;
153 tree->Branch(iter->name.c_str(),
154 "std::vector<double>",
157 tree->Branch(iter->name.c_str(), &pos->value,
158 (iter->name +
"/D").c_str());
162 void TreeSaver::passFlags(
const std::vector<Variable::Flags> &
flags)
164 assert(flags.size() == vars.size());
165 unsigned int idx = 0;
166 for(std::vector<Variable::Flags>::const_iterator iter = flags.begin();
167 iter != flags.end(); ++iter, idx++)
168 vars[idx].flags = *iter;
170 if (begun && !flagsPassed)
175 void TreeSaver::trainBegin()
178 ROOTContextSentinel ctx;
180 file = std::auto_ptr<TFile>(TFile::Open(
181 trainer->trainFileName(
this,
"root").c_str(),
185 <<
"Could not open ROOT file for writing."
189 tree =
new TTree(getTreeName().c_str(),
190 "MVATrainer signal and background");
192 if (!begun && flagsPassed)
198 void TreeSaver::trainData(
const std::vector<double> *
values,
206 for(
unsigned int i = 0; i < vars.size(); i++, values++) {
210 else if (values->empty())
213 var.value = values->front();
219 void TreeSaver::trainEnd()
224 ROOTContextSentinel ctx;
std::vector< Variable::Flags > flags
detail::ThreadSafeRegistry< ParameterSetID, ParameterSet, ProcessParameterSetIDCache > Registry
std::string getName(Reflex::Type &cc)