12 #include <xercesc/dom/DOM.hpp>
14 #include <TDirectory.h>
19 #include <TMVA/Types.h>
20 #include <TMVA/Factory.h>
21 #include <TMVA/DataLoader.h>
35 XERCES_CPP_NAMESPACE_USE
37 using namespace PhysicsTools;
// Scope guard for ROOT's global context. The constructor records the current
// values of gDirectory and gFile; the destructor assigns both back, so any
// change made to those globals while the object is alive is undone when it
// goes out of scope.
// NOTE(review): the declarations of the `dir` / `file` members and the end of
// the class are not visible in this excerpt.
41 class ROOTContextSentinel {
// Capture the globals as they are right now.
43 ROOTContextSentinel() :
dir(gDirectory),
file(gFile) {}
// Put the captured values back.
44 ~ROOTContextSentinel() { gDirectory =
dir; gFile =
file; }
// Overridden processor hooks; the definitions appear further down as
// ProcTMVA::configure / trainBegin / trainData / trainEnd.
// NOTE(review): the enclosing class header is not visible in this excerpt.

// Reads the processor's XML configuration element.
59 virtual void configure(DOMElement *
elem)
override;
// Opens the training ROOT file and creates the signal/background trees.
62 virtual void trainBegin()
override;
// Receives one training entry (declaration is cut off after `values`).
63 virtual void trainData(
const std::vector<double> *
values,
// Re-opens the training ROOT file and fetches the trees back from it.
65 virtual void trainEnd()
override;
67 virtual bool load()
override;
68 virtual void cleanup()
override;
// Non-virtual helper that drives TMVA::Factory / TMVA::DataLoader.
71 void runTMVATrainer();
// TMVA method kind. configure() assigns `method.type` from the "type" XML
// attribute via TMVA::Types::GetMethodType — presumably this is that field
// (the enclosing struct is not visible here).
74 TMVA::Types::EMVA
type;
// NOTE(review): the signature belonging to this body is not visible in this
// excerpt; it is presumably getTreeName(), which getWeightsFile() and the
// tree / factory set-up code call.
// Returns the trainer's name, an underscore, and this object's own name
// (getName() converted with a C-style cast to const char*).
80 {
return trainer->getName() +
'_' + (
const char*)getName(); }
82 std::string getWeightsFile(
const Method &meth,
const char *ext)
const
84 return "weights/" + getTreeName() +
'_' +
85 meth.name +
".weights." + ext;
// Methods collected by configure(); each one is booked in runTMVATrainer().
93 std::vector<Method> methods;
// One name per input variable, filled in configure(). Used as the tree branch
// names in trainBegin() and as the TMVA variable names in runTMVATrainer().
94 std::vector<std::string>
names;
// ROOT file assigned from TFile::Open in trainBegin() and trainEnd().
// NOTE(review): std::auto_ptr is deprecated since C++11 and removed in C++17;
// other parts of this file already use std::unique_ptr.
95 std::auto_ptr<TFile>
file;
// Signal / background trees: created in trainBegin(), fetched back from the
// file in trainEnd(), handed to the DataLoader in runTMVATrainer().
96 TTree *treeSig, *treeBkg;
// Per-variable branch buffers. trainBegin() sizes this to names.size() and
// then hands the address of each element to both trees.
98 std::vector<Double_t> vars;
// Entry counters: reset to 0 in trainBegin(); runTMVATrainer() checks that
// both are at least 1.
100 unsigned long nSignal;
101 unsigned long nBackground;
// Set to true in configure()'s <setup> branch; selects the user-supplied
// cuts/options call in runTMVATrainer().
102 bool doUserTreeSetup;
// Constructor. Only part of the parameter list and of the initializer list is
// visible in this excerpt.
109 ProcTMVA::ProcTMVA(
const char *
name,
const AtomicId *
id,
// Visible initial state: iteration starts at ITER_EXPORT, both tree pointers
// are null, needCleanup and doUserTreeSetup are false.
112 iteration(ITER_EXPORT), treeSig(0), treeBkg(0), needCleanup(
false),
// Default option string for setupOptions. Note it is spelled with spaces
// around '=', unlike the literal passed in runTMVATrainer()'s other branch.
113 doUserTreeSetup(
false), setupOptions(
"SplitMode = Block:!V")
// Destructor (body not visible in this excerpt).
117 ProcTMVA::~ProcTMVA()
// Parses the processor's configuration element.
// First pass: walks the input variables and appends one entry per input to
// `names`. Second pass: walks the child nodes of `elem`, accepting only
// <method> and <setup> elements.
// NOTE(review): several statements of this function are missing from this
// excerpt; comments below describe only what is visible.
121 void ProcTMVA::configure(DOMElement *
elem)
123 std::vector<SourceVariable*>
inputs = getInputs().get();
125 for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
126 iter != inputs.end(); iter++) {
// Open-ended counter loop: builds "<name>_<i>" for i = 1, 2, ... and tests the
// candidate against names.end() — presumably the result of a std::find over
// `names`, used to pick a suffix that is not taken yet (the find call itself
// is not visible).
131 for(
unsigned i = 1;;
i++) {
132 std::ostringstream
ss;
133 ss << name <<
"_" <<
i;
135 ss.str()) ==
names.end()) {
142 names.push_back(name);
// Iterate over all direct children of the configuration element.
145 for(DOMNode *node = elem->getFirstChild();
146 node; node = node->getNextSibling()) {
// Non-element nodes are singled out here; the action taken is not visible
// (presumably they are skipped).
147 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
// Classify the child by tag name.
150 bool isMethod = !std::strcmp(
XMLSimpleStr(node->getNodeName()),
"method");
151 bool isSetup = !std::strcmp(
XMLSimpleStr(node->getNodeName()),
"setup");
// Any other tag is reported as an error.
153 if (!isMethod && !isSetup)
155 <<
"Expected method or setup tag in config section."
// The `elem` parameter is reused from here on to refer to the current child.
158 elem =
static_cast<DOMElement*
>(node);
// Translate the "type" attribute string into a TMVA method type.
162 method.type = TMVA::Types::Instance().GetMethodType(
163 XMLDocument::readAttribute<std::string>(
164 elem,
"type").c_str());
167 XMLDocument::readAttribute<std::string>(
173 methods.push_back(method);
174 }
else if (isSetup) {
// The message indicates that a repeated <setup> tag is rejected; the
// condition guarding it is not visible.
177 <<
"Multiple appeareances of setup "
178 "tag in config section."
181 doUserTreeSetup =
true;
184 XMLDocument::readAttribute<std::string>(
187 XMLDocument::readAttribute<std::string>(
// Final error message; per its text it covers the case of no TMVA method
// being configured (the guarding condition is not visible).
194 <<
"Expected TMVA method in config section."
// NOTE(review): the header of the enclosing function is not visible in this
// excerpt — presumably ProcTMVA::load().
// For every configured method, opens an input stream on its ".weights.xml"
// file; what is done with the stream afterwards is not visible here.
201 for(std::vector<Method>::const_iterator iter = methods.begin();
202 iter != methods.end(); ++iter) {
203 std::ifstream
in(getWeightsFile(*iter,
"xml").c_str());
/// Returns the number of bytes between the stream's current read position and
/// the end of the stream, and restores the read position before returning.
///
/// @param in  Open input file stream. Its get position is moved to the end
///            and then put back where it was.
/// @return    Remaining byte count, or 0 if the position cannot be queried
///            (tellg() reports failure).
static std::size_t getStreamSize(std::ifstream &in)
{
	const std::ifstream::pos_type failed(-1);

	const std::ifstream::pos_type begin = in.tellg();
	if (begin == failed)
		return 0;

	// The second tellg() must be taken at the end of the stream; without this
	// seek both positions are identical and the result is always 0.
	in.seekg(0, std::ios::end);
	const std::ifstream::pos_type end = in.tellg();

	// Leave the stream where the caller had it.
	in.seekg(begin);

	if (end == failed)
		return 0;
	return static_cast<std::size_t>(end - begin);
}
// NOTE(review): the header of the enclosing function is not visible in this
// excerpt. The visible code packs the first method's XML weights file,
// together with the method name and the variable names, into `calib->store`.
// Only methods[0] is referenced here.
232 std::ifstream
in(getWeightsFile(methods[0],
"xml").c_str(),
// Error text for a weights file that cannot be opened (the check itself is
// not visible).
236 <<
"Weights file " << getWeightsFile(methods[0],
"xml")
237 <<
" cannot be opened for reading." << std::endl;
// Buffer size estimate: file size + method name + every variable name (plus
// one byte each), then padded by size/32 + 128 bytes of headroom.
239 std::size_t
size = getStreamSize(in) + methods[0].name.size();
240 for(std::vector<std::string>::const_iterator iter =
names.begin();
241 iter !=
names.end(); ++iter)
242 size += iter->size() + 1;
243 size += (size / 32) + 128;
// NOTE(review): std::shared_ptr<char> constructed from `new char[size]` with
// no custom deleter releases the array with `delete`, not `delete[]`, which
// is undefined behavior. std::unique_ptr<char[]> or std::vector<char> would
// be correct — confirm and fix.
245 std::shared_ptr<char> buffer(
new char[size] );
// Header written to `ozs`: method name, number of variables, then one
// variable name per line. `ozs` / `os` are declared on lines not visible here
// (presumably an output stream layered over `buffer`).
249 ozs << methods[0].name <<
"\n";
250 ozs <<
names.size() <<
"\n";
251 for(std::vector<std::string>::const_iterator iter =
253 iter !=
names.end(); ++iter)
254 ozs << *iter <<
"\n";
// `size` is reused for the number of bytes actually produced, which are then
// copied into the calibration object's store.
// NOTE(review): &calib->store.front() is only valid when size > 0.
258 size = os.end() - os.begin();
259 calib->
store.resize(size);
260 std::memcpy(&calib->
store.front(), os.begin(),
size);
264 calib->
method =
"ProcTMVA";
// Prepares the training trees: opens the training ROOT file, creates a signal
// and a background TTree, and attaches a weight branch plus one branch per
// input variable to each of them.
// NOTE(review): parts of this function are missing from this excerpt.
269 void ProcTMVA::trainBegin()
// Restores gDirectory / gFile when this function returns.
272 ROOTContextSentinel ctx;
// NOTE(review): std::auto_ptr was removed in C++17.
274 file = std::auto_ptr<TFile>(TFile::Open(
275 trainer->trainFileName(
this,
"root",
280 <<
"Could not open ROOT file for writing."
// Tree names are getTreeName() + "_sig" / "_bkg".
284 treeSig =
new TTree((getTreeName() +
"_sig").c_str(),
285 "MVATrainer signal");
286 treeBkg =
new TTree((getTreeName() +
"_bkg").c_str(),
287 "MVATrainer background");
// Both trees read the event weight from the same `weight` variable.
289 treeSig->Branch(
"__WEIGHT__", &
weight,
"__WEIGHT__/D");
290 treeBkg->Branch(
"__WEIGHT__", &
weight,
"__WEIGHT__/D");
// `vars` is sized before the loop because the branches below are given the
// addresses of its elements.
292 vars.resize(
names.size());
// One double ("/D") branch per variable name, on both trees, each pointing at
// the matching element of `vars`.
294 std::vector<Double_t>::iterator pos = vars.begin();
295 for(std::vector<std::string>::const_iterator iter =
296 names.begin(); iter !=
names.end(); iter++, pos++) {
297 treeSig->Branch(iter->c_str(), &*pos,
298 (*iter +
"/D").c_str());
299 treeBkg->Branch(iter->c_str(), &*pos,
300 (*iter +
"/D").c_str());
// Reset the entry counters.
303 nSignal = nBackground = 0;
// Receives one training entry. The remaining parameters and most of the body
// are not visible in this excerpt.
307 void ProcTMVA::trainData(
const std::vector<double> *
values,
// `values` is walked as an array with one std::vector<double> per variable;
// the first element of each is copied into the matching branch buffer.
// NOTE(review): front() on an empty vector is undefined — presumably callers
// guarantee at least one value per variable; confirm.
314 for(
unsigned int i = 0; i < vars.size(); i++, values++)
315 vars[i] = values->front();
// Runs TMVA on the trees: sets up a Factory and a DataLoader, registers the
// input variables and the weight expression, books every configured method,
// then trains, tests and evaluates them.
// NOTE(review): parts of this function are missing from this excerpt.
326 void ProcTMVA::runTMVATrainer()
// Both classes need at least one entry; otherwise the message below is
// emitted (how it is raised is not visible).
330 if (nSignal < 1 || nBackground < 1)
332 <<
"Not going to run TMVA: "
333 "No signal (" << nSignal <<
") or background ("
334 << nBackground <<
") events!" << std::endl;
// Local output file for TMVA. This local `file` shadows the member of the
// same name.
// NOTE(review): std::auto_ptr was removed in C++17.
336 std::auto_ptr<TFile>
file(TFile::Open(
337 trainer->trainFileName(
this,
"root",
"output").c_str(),
341 <<
"Could not open TMVA ROOT file for writing."
344 std::unique_ptr<TMVA::Factory> factory(
345 new TMVA::Factory(getTreeName().c_str(),
file.get(),
""));
347 std::unique_ptr<TMVA::DataLoader> loader(
new TMVA::DataLoader(
"ProcTMVA"));
349 loader->SetInputTrees(treeSig, treeBkg);
// Register every input variable as type 'D', using the same names as the
// tree branches created in trainBegin().
351 for(std::vector<std::string>::const_iterator iter =
names.begin();
352 iter !=
names.end(); iter++)
353 loader->AddVariable(iter->c_str(),
'D');
// Per-entry weight comes from the "__WEIGHT__" branch.
355 loader->SetWeightExpression(
"__WEIGHT__");
// With a <setup> tag the user-supplied cuts and options are used; the other
// call passes the literal "SplitMode=Block:!V" (its first argument is not
// visible here).
357 if (doUserTreeSetup) {
358 loader->PrepareTrainingAndTestTree(
359 setupCuts.c_str(), setupOptions);
361 loader->PrepareTrainingAndTestTree(
363 "SplitMode=Block:!V");
// Book each configured method on the loader.
366 for(std::vector<Method>::const_iterator iter = methods.begin();
367 iter != methods.end(); ++iter)
368 factory->BookMethod(loader.get(), iter->type, iter->name, iter->description);
370 factory->TrainAllMethods();
371 factory->TestAllMethods();
372 factory->EvaluateAllMethods();
379 printf(
"TMVA training factory completed\n");
// End-of-training hook. The visible part re-opens the training ROOT file and
// fetches the signal / background trees back from it.
// NOTE(review): much of this function is missing from this excerpt.
382 void ProcTMVA::trainEnd()
// Restores gDirectory / gFile when this function returns.
387 ROOTContextSentinel ctx;
// NOTE(review): std::auto_ptr was removed in C++17.
394 file = std::auto_ptr<TFile>(TFile::Open(
395 trainer->trainFileName(
this,
"root",
399 <<
"Could not open ROOT file for "
400 "reading." << std::endl;
// Look the trees up under the names used in trainBegin().
// NOTE(review): dynamic_cast yields a null pointer if the object is missing
// or is not a TTree; no null check is visible here — confirm one exists.
401 treeSig =
dynamic_cast<TTree*
>(
402 file->Get((getTreeName() +
"_sig").c_str()));
403 treeBkg =
dynamic_cast<TTree*
>(
404 file->Get((getTreeName() +
"_bkg").c_str()));
// NOTE(review): the header of the enclosing function is not visible in this
// excerpt — presumably ProcTMVA::cleanup().
// Deletes the trainer's "input" and "output" ROOT files, then, for every
// configured method, its ".weights.root" file. The return values of
// std::remove are not checked.
428 std::remove(trainer->trainFileName(
this,
"root",
"input").c_str());
429 std::remove(trainer->trainFileName(
this,
"root",
"output").c_str());
430 for(std::vector<Method>::const_iterator iter = methods.begin();
431 iter != methods.end(); ++iter) {
433 std::remove(getWeightsFile(*iter,
"root").c_str());
static const HistoName names[]
static void cleanup(const Factory::MakerMap::value_type &v)
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
MVATrainerComputer * calib
volatile std::atomic< bool > shutdown_flag false
tuple size
Write out results.
#define MVA_TRAINER_DEFINE_PLUGIN(T)