12 #include <xercesc/dom/DOM.hpp>
14 #include <TDirectory.h>
19 #include <TMVA/Types.h>
20 #include <TMVA/Factory.h>
34 XERCES_CPP_NAMESPACE_USE
36 using namespace PhysicsTools;
40 class ROOTContextSentinel {
42 ROOTContextSentinel() :
dir(gDirectory),
file(gFile) {}
43 ~ROOTContextSentinel() { gDirectory =
dir; gFile =
file; }
58 virtual void configure(DOMElement *
elem)
override;
61 virtual void trainBegin()
override;
62 virtual void trainData(
const std::vector<double> *
values,
64 virtual void trainEnd()
override;
66 virtual bool load()
override;
67 virtual void cleanup()
override;
70 void runTMVATrainer();
73 TMVA::Types::EMVA
type;
79 {
return trainer->getName() +
'_' + (
const char*)getName(); }
81 std::string getWeightsFile(
const Method &meth,
const char *ext)
const
83 return "weights/" + getTreeName() +
'_' +
84 meth.name +
".weights." + ext;
92 std::vector<Method> methods;
93 std::vector<std::string>
names;
94 std::auto_ptr<TFile>
file;
95 TTree *treeSig, *treeBkg;
97 std::vector<Double_t> vars;
99 unsigned long nSignal;
100 unsigned long nBackground;
101 bool doUserTreeSetup;
108 ProcTMVA::ProcTMVA(
const char *
name,
const AtomicId *
id,
111 iteration(ITER_EXPORT), treeSig(0), treeBkg(0), needCleanup(
false),
112 doUserTreeSetup(
false), setupOptions(
"SplitMode = Block:!V")
116 ProcTMVA::~ProcTMVA()
120 void ProcTMVA::configure(DOMElement *
elem)
122 std::vector<SourceVariable*>
inputs = getInputs().get();
124 for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
125 iter != inputs.end(); iter++) {
130 for(
unsigned i = 1;;
i++) {
131 std::ostringstream
ss;
132 ss << name <<
"_" <<
i;
134 ss.str()) ==
names.end()) {
141 names.push_back(name);
144 for(DOMNode *
node = elem->getFirstChild();
146 if (
node->getNodeType() != DOMNode::ELEMENT_NODE)
152 if (!isMethod && !isSetup)
154 <<
"Expected method or setup tag in config section."
157 elem =
static_cast<DOMElement*
>(
node);
161 method.type = TMVA::Types::Instance().GetMethodType(
162 XMLDocument::readAttribute<std::string>(
163 elem,
"type").c_str());
166 XMLDocument::readAttribute<std::string>(
172 methods.push_back(method);
173 }
else if (isSetup) {
176 <<
"Multiple appeareances of setup "
177 "tag in config section."
180 doUserTreeSetup =
true;
183 XMLDocument::readAttribute<std::string>(
186 XMLDocument::readAttribute<std::string>(
193 <<
"Expected TMVA method in config section."
200 for(std::vector<Method>::const_iterator iter = methods.begin();
201 iter != methods.end(); ++iter) {
202 std::ifstream
in(getWeightsFile(*iter,
"xml").c_str());
217 static std::size_t getStreamSize(std::ifstream &
in)
219 std::ifstream::pos_type
begin = in.tellg();
221 std::ifstream::pos_type
end = in.tellg();
222 in.seekg(begin, std::ios::beg);
224 return (std::size_t)(end -
begin);
231 std::ifstream
in(getWeightsFile(methods[0],
"xml").c_str(),
235 <<
"Weights file " << getWeightsFile(methods[0],
"xml")
236 <<
" cannot be opened for reading." << std::endl;
238 std::size_t
size = getStreamSize(in) + methods[0].name.size();
239 for(std::vector<std::string>::const_iterator iter =
names.begin();
240 iter !=
names.end(); ++iter)
241 size += iter->size() + 1;
242 size += (size / 32) + 128;
244 std::shared_ptr<char> buffer(
new char[size] );
248 ozs << methods[0].name <<
"\n";
249 ozs <<
names.size() <<
"\n";
250 for(std::vector<std::string>::const_iterator iter =
252 iter !=
names.end(); ++iter)
253 ozs << *iter <<
"\n";
257 size = os.end() - os.begin();
258 calib->
store.resize(size);
259 std::memcpy(&calib->
store.front(), os.begin(),
size);
263 calib->
method =
"ProcTMVA";
268 void ProcTMVA::trainBegin()
271 ROOTContextSentinel ctx;
273 file = std::auto_ptr<TFile>(TFile::Open(
274 trainer->trainFileName(
this,
"root",
279 <<
"Could not open ROOT file for writing."
283 treeSig =
new TTree((getTreeName() +
"_sig").c_str(),
284 "MVATrainer signal");
285 treeBkg =
new TTree((getTreeName() +
"_bkg").c_str(),
286 "MVATrainer background");
288 treeSig->Branch(
"__WEIGHT__", &
weight,
"__WEIGHT__/D");
289 treeBkg->Branch(
"__WEIGHT__", &
weight,
"__WEIGHT__/D");
291 vars.resize(
names.size());
293 std::vector<Double_t>::iterator pos = vars.begin();
294 for(std::vector<std::string>::const_iterator iter =
295 names.begin(); iter !=
names.end(); iter++, pos++) {
296 treeSig->Branch(iter->c_str(), &*pos,
297 (*iter +
"/D").c_str());
298 treeBkg->Branch(iter->c_str(), &*pos,
299 (*iter +
"/D").c_str());
302 nSignal = nBackground = 0;
306 void ProcTMVA::trainData(
const std::vector<double> *
values,
313 for(
unsigned int i = 0; i < vars.size(); i++, values++)
314 vars[i] = values->front();
325 void ProcTMVA::runTMVATrainer()
329 if (nSignal < 1 || nBackground < 1)
331 <<
"Not going to run TMVA: "
332 "No signal (" << nSignal <<
") or background ("
333 << nBackground <<
") events!" << std::endl;
335 std::auto_ptr<TFile>
file(TFile::Open(
336 trainer->trainFileName(
this,
"root",
"output").c_str(),
340 <<
"Could not open TMVA ROOT file for writing."
343 std::auto_ptr<TMVA::Factory> factory(
344 new TMVA::Factory(getTreeName().c_str(),
file.get(),
""));
346 factory->SetInputTrees(treeSig, treeBkg);
348 for(std::vector<std::string>::const_iterator iter =
names.begin();
349 iter !=
names.end(); iter++)
350 factory->AddVariable(iter->c_str(),
'D');
352 factory->SetWeightExpression(
"__WEIGHT__");
355 factory->PrepareTrainingAndTestTree(
356 setupCuts.c_str(), setupOptions);
358 factory->PrepareTrainingAndTestTree(
360 "SplitMode=Block:!V");
362 for(std::vector<Method>::const_iterator iter = methods.begin();
363 iter != methods.end(); ++iter)
364 factory->BookMethod(iter->type, iter->name, iter->description);
366 factory->TrainAllMethods();
367 factory->TestAllMethods();
368 factory->EvaluateAllMethods();
374 printf(
"TMVA training factory completed\n");
377 void ProcTMVA::trainEnd()
382 ROOTContextSentinel ctx;
389 file = std::auto_ptr<TFile>(TFile::Open(
390 trainer->trainFileName(
this,
"root",
394 <<
"Could not open ROOT file for "
395 "reading." << std::endl;
396 treeSig =
dynamic_cast<TTree*
>(
397 file->Get((getTreeName() +
"_sig").c_str()));
398 treeBkg =
dynamic_cast<TTree*
>(
399 file->Get((getTreeName() +
"_bkg").c_str()));
423 std::remove(trainer->trainFileName(
this,
"root",
"input").c_str());
424 std::remove(trainer->trainFileName(
this,
"root",
"output").c_str());
425 for(std::vector<Method>::const_iterator iter = methods.begin();
426 iter != methods.end(); ++iter) {
428 std::remove(getWeightsFile(*iter,
"root").c_str());
static const HistoName names[]
static void cleanup(const Factory::MakerMap::value_type &v)
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
MVATrainerComputer * calib
volatile std::atomic< bool > shutdown_flag false
tuple size
Write out results.
#define MVA_TRAINER_DEFINE_PLUGIN(T)