12 #include <xercesc/dom/DOM.hpp>
14 #include <TDirectory.h>
19 #include <TMVA/Types.h>
20 #include <TMVA/Factory.h>
34 XERCES_CPP_NAMESPACE_USE
36 using namespace PhysicsTools;
40 class ROOTContextSentinel {
42 ROOTContextSentinel() :
dir(gDirectory),
file(gFile) {}
43 ~ROOTContextSentinel() { gDirectory =
dir; gFile =
file; }
58 virtual void configure(DOMElement *
elem);
61 virtual void trainBegin();
62 virtual void trainData(
const std::vector<double> *
values,
64 virtual void trainEnd();
70 void runTMVATrainer();
73 TMVA::Types::EMVA
type;
78 std::string getTreeName()
const
79 {
return trainer->getName() +
'_' + (
const char*)
getName(); }
81 std::string getWeightsFile(
const Method &meth,
const char *ext)
const
83 return "weights/" + getTreeName() +
'_' +
84 meth.name +
".weights." + ext;
92 std::vector<Method> methods;
93 std::vector<std::string>
names;
94 std::auto_ptr<TFile>
file;
95 TTree *treeSig, *treeBkg;
97 std::vector<Double_t> vars;
99 unsigned long nSignal;
100 unsigned long nBackground;
101 bool doUserTreeSetup;
102 std::string setupCuts;
103 std::string setupOptions;
108 ProcTMVA::ProcTMVA(
const char *
name,
const AtomicId *
id,
111 iteration(ITER_EXPORT), treeSig(0), treeBkg(0), needCleanup(
false),
112 doUserTreeSetup(
false), setupOptions(
"SplitMode = Block:!V")
116 ProcTMVA::~ProcTMVA()
120 void ProcTMVA::configure(DOMElement *
elem)
122 std::vector<SourceVariable*>
inputs = getInputs().get();
124 for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
125 iter != inputs.end(); iter++) {
126 std::string
name = (
const char*)(*iter)->getName();
130 for(
unsigned i = 1;;
i++) {
131 std::ostringstream ss;
132 ss << name <<
"_" <<
i;
134 ss.str()) ==
names.end()) {
141 names.push_back(name);
144 for(DOMNode *
node = elem->getFirstChild();
146 if (
node->getNodeType() != DOMNode::ELEMENT_NODE)
152 if (!isMethod && !isSetup)
154 <<
"Expected method or setup tag in config section."
157 elem =
static_cast<DOMElement*
>(
node);
161 method.type = TMVA::Types::Instance().GetMethodType(
162 XMLDocument::readAttribute<std::string>(
163 elem,
"type").c_str());
166 XMLDocument::readAttribute<std::string>(
172 methods.push_back(method);
173 }
else if (isSetup) {
176 <<
"Multiple appeareances of setup "
177 "tag in config section."
180 doUserTreeSetup =
true;
183 XMLDocument::readAttribute<std::string>(
186 XMLDocument::readAttribute<std::string>(
193 <<
"Expected TMVA method in config section."
200 for(std::vector<Method>::const_iterator iter = methods.begin();
201 iter != methods.end(); ++iter) {
202 std::ifstream
in(getWeightsFile(*iter,
"xml").c_str());
217 static std::size_t getStreamSize(std::ifstream &
in)
219 std::ifstream::pos_type
begin = in.tellg();
221 std::ifstream::pos_type
end = in.tellg();
222 in.seekg(begin, std::ios::beg);
224 return (std::size_t)(end -
begin);
231 std::ifstream
in(getWeightsFile(methods[0],
"xml").c_str(),
235 <<
"Weights file " << getWeightsFile(methods[0],
"xml")
236 <<
" cannot be opened for reading." << std::endl;
238 std::size_t
size = getStreamSize(in) + methods[0].name.size();
239 for(std::vector<std::string>::const_iterator iter =
names.begin();
240 iter !=
names.end(); ++iter)
241 size += iter->size() + 1;
242 size += (size / 32) + 128;
246 buffer =
new char[
size];
250 ozs << methods[0].name <<
"\n";
251 ozs <<
names.size() <<
"\n";
252 for(std::vector<std::string>::const_iterator iter =
254 iter !=
names.end(); ++iter)
255 ozs << *iter <<
"\n";
259 size = os.end() - os.begin();
260 calib->
store.resize(size);
261 std::memcpy(&calib->
store.front(), os.begin(),
size);
269 calib->
method =
"ProcTMVA";
274 void ProcTMVA::trainBegin()
277 ROOTContextSentinel ctx;
279 file = std::auto_ptr<TFile>(TFile::Open(
280 trainer->trainFileName(
this,
"root",
285 <<
"Could not open ROOT file for writing."
289 treeSig =
new TTree((getTreeName() +
"_sig").c_str(),
290 "MVATrainer signal");
291 treeBkg =
new TTree((getTreeName() +
"_bkg").c_str(),
292 "MVATrainer background");
294 treeSig->Branch(
"__WEIGHT__", &
weight,
"__WEIGHT__/D");
295 treeBkg->Branch(
"__WEIGHT__", &
weight,
"__WEIGHT__/D");
297 vars.resize(
names.size());
299 std::vector<Double_t>::iterator
pos = vars.begin();
300 for(std::vector<std::string>::const_iterator iter =
301 names.begin(); iter !=
names.end(); iter++, pos++) {
302 treeSig->Branch(iter->c_str(), &*
pos,
303 (*iter +
"/D").c_str());
304 treeBkg->Branch(iter->c_str(), &*
pos,
305 (*iter +
"/D").c_str());
308 nSignal = nBackground = 0;
312 void ProcTMVA::trainData(
const std::vector<double> *
values,
319 for(
unsigned int i = 0; i < vars.size(); i++, values++)
320 vars[i] = values->front();
331 void ProcTMVA::runTMVATrainer()
335 if (nSignal < 1 || nBackground < 1)
337 <<
"Not going to run TMVA: "
338 "No signal (" << nSignal <<
") or background ("
339 << nBackground <<
") events!" << std::endl;
341 std::auto_ptr<TFile>
file(TFile::Open(
342 trainer->trainFileName(
this,
"root",
"output").c_str(),
346 <<
"Could not open TMVA ROOT file for writing."
349 std::auto_ptr<TMVA::Factory> factory(
350 new TMVA::Factory(getTreeName().c_str(),
file.get(),
""));
352 factory->SetInputTrees(treeSig, treeBkg);
354 for(std::vector<std::string>::const_iterator iter =
names.begin();
355 iter !=
names.end(); iter++)
356 factory->AddVariable(iter->c_str(),
'D');
358 factory->SetWeightExpression(
"__WEIGHT__");
361 factory->PrepareTrainingAndTestTree(
362 setupCuts.c_str(), setupOptions);
364 factory->PrepareTrainingAndTestTree(
366 "SplitMode=Block:!V");
368 for(std::vector<Method>::const_iterator iter = methods.begin();
369 iter != methods.end(); ++iter)
370 factory->BookMethod(iter->type, iter->name, iter->description);
372 factory->TrainAllMethods();
373 factory->TestAllMethods();
374 factory->EvaluateAllMethods();
380 printf(
"TMVA training factory completed\n");
383 void ProcTMVA::trainEnd()
388 ROOTContextSentinel ctx;
395 file = std::auto_ptr<TFile>(TFile::Open(
396 trainer->trainFileName(
this,
"root",
400 <<
"Could not open ROOT file for "
401 "reading." << std::endl;
402 treeSig =
dynamic_cast<TTree*
>(
403 file->Get((getTreeName() +
"_sig").c_str()));
404 treeBkg =
dynamic_cast<TTree*
>(
405 file->Get((getTreeName() +
"_bkg").c_str()));
429 std::remove(trainer->trainFileName(
this,
"root",
"input").c_str());
430 std::remove(trainer->trainFileName(
this,
"root",
"output").c_str());
431 for(std::vector<Method>::const_iterator iter = methods.begin();
432 iter != methods.end(); ++iter) {
434 std::remove(getWeightsFile(*iter,
"root").c_str());
MVA_TRAINER_DEFINE_PLUGIN(ProcTMVA)
static void cleanup(const Factory::MakerMap::value_type &v)
void find(edm::Handle< EcalRecHitCollection > &hits, DetId thisDet, std::vector< EcalRecHitCollection::const_iterator > &hit, bool debug=false)
detail::ThreadSafeRegistry< ParameterSetID, ParameterSet, ProcessParameterSetIDCache > Registry
MVATrainerComputer * calib
std::string getName(Reflex::Type &cc)
static const HistoName names[]
tuple size
Write out results.