00001 #include <unistd.h>
00002 #include <algorithm>
00003 #include <iostream>
00004 #include <sstream>
00005 #include <fstream>
00006 #include <cstddef>
00007 #include <cstring>
00008 #include <cstdio>
00009 #include <vector>
00010 #include <memory>
00011
00012 #include <xercesc/dom/DOM.hpp>
00013
00014
00015 #include <RVersion.h>
00016
00017 #include <TDirectory.h>
00018 #include <TTree.h>
00019 #include <TFile.h>
00020 #include <TCut.h>
00021
00022 #include <TMVA/Types.h>
00023 #include <TMVA/Factory.h>
00024
00025 #include "FWCore/Utilities/interface/Exception.h"
00026
00027 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00028 #include "PhysicsTools/MVAComputer/interface/memstream.h"
00029 #include "PhysicsTools/MVAComputer/interface/zstream.h"
00030
00031 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00032 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00033 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00034 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00035 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00036
00037 XERCES_CPP_NAMESPACE_USE
00038
00039 using namespace PhysicsTools;
00040
00041 namespace {
00042
00043 class ROOTContextSentinel {
00044 public:
00045 ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
00046 ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
00047
00048 private:
00049 TDirectory *dir;
00050 TFile *file;
00051 };
00052
00053 class ProcTMVA : public TrainProcessor {
00054 public:
00055 typedef TrainProcessor::Registry<ProcTMVA>::Type Registry;
00056
00057 ProcTMVA(const char *name, const AtomicId *id,
00058 MVATrainer *trainer);
00059 virtual ~ProcTMVA();
00060
00061 virtual void configure(DOMElement *elem);
00062 virtual Calibration::VarProcessor *getCalibration() const;
00063
00064 virtual void trainBegin();
00065 virtual void trainData(const std::vector<double> *values,
00066 bool target, double weight);
00067 virtual void trainEnd();
00068
00069 virtual bool load();
00070 virtual void cleanup();
00071
00072 private:
00073 void runTMVATrainer();
00074
00075 struct Method {
00076 TMVA::Types::EMVA type;
00077 std::string name;
00078 std::string description;
00079 };
00080
00081 std::string getTreeName() const
00082 { return trainer->getName() + '_' + (const char*)getName(); }
00083
00084 std::string getWeightsFile(const Method &meth, const char *ext) const
00085 {
00086 return "weights/" + getTreeName() + '_' +
00087 meth.name + ".weights." + ext;
00088 }
00089
00090 enum Iteration {
00091 ITER_EXPORT,
00092 ITER_DONE
00093 } iteration;
00094
00095 std::vector<Method> methods;
00096 std::vector<std::string> names;
00097 std::auto_ptr<TFile> file;
00098 TTree *treeSig, *treeBkg;
00099 Double_t weight;
00100 std::vector<Double_t> vars;
00101 bool needCleanup;
00102 unsigned long nSignal;
00103 unsigned long nBackground;
00104 bool doUserTreeSetup;
00105 std::string setupCuts;
00106 std::string setupOptions;
00107 };
00108
00109 static ProcTMVA::Registry registry("ProcTMVA");
00110
00111 ProcTMVA::ProcTMVA(const char *name, const AtomicId *id,
00112 MVATrainer *trainer) :
00113 TrainProcessor(name, id, trainer),
00114 iteration(ITER_EXPORT), treeSig(0), treeBkg(0), needCleanup(false),
00115 doUserTreeSetup(false), setupOptions("SplitMode = Block:!V")
00116 {
00117 }
00118
00119 ProcTMVA::~ProcTMVA()
00120 {
00121 }
00122
00123 void ProcTMVA::configure(DOMElement *elem)
00124 {
00125 std::vector<SourceVariable*> inputs = getInputs().get();
00126
00127 for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
00128 iter != inputs.end(); iter++) {
00129 std::string name = (const char*)(*iter)->getName();
00130
00131 if (std::find(names.begin(), names.end(), name)
00132 != names.end()) {
00133 for(unsigned i = 1;; i++) {
00134 std::ostringstream ss;
00135 ss << name << "_" << i;
00136 if (std::find(names.begin(), names.end(),
00137 ss.str()) == names.end()) {
00138 name == ss.str();
00139 break;
00140 }
00141 }
00142 }
00143
00144 names.push_back(name);
00145 }
00146
00147 for(DOMNode *node = elem->getFirstChild();
00148 node; node = node->getNextSibling()) {
00149 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00150 continue;
00151
00152 bool isMethod = !std::strcmp(XMLSimpleStr(node->getNodeName()), "method");
00153 bool isSetup = !std::strcmp(XMLSimpleStr(node->getNodeName()), "setup");
00154
00155 if (!isMethod && !isSetup)
00156 throw cms::Exception("ProcTMVA")
00157 << "Expected method or setup tag in config section."
00158 << std::endl;
00159
00160 elem = static_cast<DOMElement*>(node);
00161
00162 if (isMethod) {
00163 Method method;
00164 method.type = TMVA::Types::Instance().GetMethodType(
00165 XMLDocument::readAttribute<std::string>(
00166 elem, "type").c_str());
00167
00168 method.name =
00169 XMLDocument::readAttribute<std::string>(
00170 elem, "name");
00171
00172 method.description =
00173 (const char*)XMLSimpleStr(node->getTextContent());
00174
00175 methods.push_back(method);
00176 } else if (isSetup) {
00177 if (doUserTreeSetup)
00178 throw cms::Exception("ProcTMVA")
00179 << "Multiple appeareances of setup "
00180 "tag in config section."
00181 << std::endl;
00182
00183 doUserTreeSetup = true;
00184
00185 setupCuts =
00186 XMLDocument::readAttribute<std::string>(
00187 elem, "cuts");
00188 setupOptions =
00189 XMLDocument::readAttribute<std::string>(
00190 elem, "options");
00191 }
00192 }
00193
00194 if (!methods.size())
00195 throw cms::Exception("ProcTMVA")
00196 << "Expected TMVA method in config section."
00197 << std::endl;
00198 }
00199
00200 bool ProcTMVA::load()
00201 {
00202 bool ok = true;
00203 for(std::vector<Method>::const_iterator iter = methods.begin();
00204 iter != methods.end(); ++iter) {
00205 std::ifstream in(getWeightsFile(*iter, "txt").c_str());
00206 if (!in.good()) {
00207 ok = false;
00208 break;
00209 }
00210 }
00211
00212 if (!ok)
00213 return false;
00214
00215 iteration = ITER_DONE;
00216 trained = true;
00217 return true;
00218 }
00219
// Returns the number of bytes between the stream's current read position and
// its end. The read position is restored before returning.
static std::size_t getStreamSize(std::ifstream &in)
{
	const std::ifstream::pos_type here = in.tellg();
	in.seekg(0, std::ios::end);
	const std::ifstream::pos_type last = in.tellg();
	in.seekg(here, std::ios::beg);	// rewind to where we started

	const std::size_t remaining = static_cast<std::size_t>(last - here);
	return remaining;
}
00229
00230 Calibration::VarProcessor *ProcTMVA::getCalibration() const
00231 {
00232 Calibration::ProcTMVA *calib = new Calibration::ProcTMVA;
00233
00234 std::ifstream in(getWeightsFile(methods[0], "txt").c_str(),
00235 std::ios::binary | std::ios::in);
00236 if (!in.good())
00237 throw cms::Exception("ProcTMVA")
00238 << "Weights file " << getWeightsFile(methods[0], "txt")
00239 << " cannot be opened for reading." << std::endl;
00240
00241 std::size_t size = getStreamSize(in);
00242 size = size + (size / 32) + 128;
00243
00244 char *buffer = 0;
00245 try {
00246 buffer = new char[size];
00247 ext::omemstream os(buffer, size);
00248 {
00249 ext::ozstream ozs(&os);
00250 ozs << in.rdbuf();
00251 ozs.flush();
00252 }
00253 size = os.end() - os.begin();
00254 calib->store.resize(size);
00255 std::memcpy(&calib->store.front(), os.begin(), size);
00256 } catch(...) {
00257 delete[] buffer;
00258 throw;
00259 }
00260 delete[] buffer;
00261 in.close();
00262
00263 calib->method = methods[0].name;
00264 calib->variables = names;
00265
00266 return calib;
00267 }
00268
00269 void ProcTMVA::trainBegin()
00270 {
00271 if (iteration == ITER_EXPORT) {
00272 ROOTContextSentinel ctx;
00273
00274 file = std::auto_ptr<TFile>(TFile::Open(
00275 trainer->trainFileName(this, "root",
00276 "input").c_str(),
00277 "RECREATE"));
00278 if (!file.get())
00279 throw cms::Exception("ProcTMVA")
00280 << "Could not open ROOT file for writing."
00281 << std::endl;
00282
00283 file->cd();
00284 treeSig = new TTree((getTreeName() + "_sig").c_str(),
00285 "MVATrainer signal");
00286 treeBkg = new TTree((getTreeName() + "_bkg").c_str(),
00287 "MVATrainer background");
00288
00289 treeSig->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
00290 treeBkg->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
00291
00292 vars.resize(names.size());
00293
00294 std::vector<Double_t>::iterator pos = vars.begin();
00295 for(std::vector<std::string>::const_iterator iter =
00296 names.begin(); iter != names.end(); iter++, pos++) {
00297 treeSig->Branch(iter->c_str(), &*pos,
00298 (*iter + "/D").c_str());
00299 treeBkg->Branch(iter->c_str(), &*pos,
00300 (*iter + "/D").c_str());
00301 }
00302
00303 nSignal = nBackground = 0;
00304 }
00305 }
00306
00307 void ProcTMVA::trainData(const std::vector<double> *values,
00308 bool target, double weight)
00309 {
00310 if (iteration != ITER_EXPORT)
00311 return;
00312
00313 this->weight = weight;
00314 for(unsigned int i = 0; i < vars.size(); i++, values++)
00315 vars[i] = values->front();
00316
00317 if (target) {
00318 treeSig->Fill();
00319 nSignal++;
00320 } else {
00321 treeBkg->Fill();
00322 nBackground++;
00323 }
00324 }
00325
00326 void ProcTMVA::runTMVATrainer()
00327 {
00328 needCleanup = true;
00329
00330 if (nSignal < 1 || nBackground < 1)
00331 throw cms::Exception("ProcTMVA")
00332 << "Not going to run TMVA: "
00333 "No signal (" << nSignal << ") or background ("
00334 << nBackground << ") events!" << std::endl;
00335
00336 std::auto_ptr<TFile> file(TFile::Open(
00337 trainer->trainFileName(this, "root", "output").c_str(),
00338 "RECREATE"));
00339 if (!file.get())
00340 throw cms::Exception("ProcTMVA")
00341 << "Could not open TMVA ROOT file for writing."
00342 << std::endl;
00343
00344 std::auto_ptr<TMVA::Factory> factory(
00345 new TMVA::Factory(getTreeName().c_str(), file.get(), ""));
00346
00347 if (!factory->SetInputTrees(treeSig, treeBkg))
00348 throw cms::Exception("ProcTMVA")
00349 << "TMVA rejected input trees." << std::endl;
00350
00351 for(std::vector<std::string>::const_iterator iter = names.begin();
00352 iter != names.end(); iter++)
00353 factory->AddVariable(iter->c_str(), 'D');
00354
00355 factory->SetWeightExpression("__WEIGHT__");
00356
00357 if (doUserTreeSetup)
00358 factory->PrepareTrainingAndTestTree(
00359 setupCuts.c_str(), setupOptions);
00360 else
00361 factory->PrepareTrainingAndTestTree(
00362 "", nSignal, nBackground, 1, 1,
00363 "SplitMode=Block:!V");
00364
00365 for(std::vector<Method>::const_iterator iter = methods.begin();
00366 iter != methods.end(); ++iter)
00367 factory->BookMethod(iter->type, iter->name, iter->description);
00368
00369 factory->TrainAllMethods();
00370 factory->TestAllMethods();
00371 factory->EvaluateAllMethods();
00372
00373 factory.release();
00374
00375 file->Close();
00376 }
00377
00378 void ProcTMVA::trainEnd()
00379 {
00380 switch(iteration) {
00381 case ITER_EXPORT:
00382
00383 treeSig->Fill();
00384 treeBkg->Fill();
00385
00386 {
00387 ROOTContextSentinel ctx;
00388 file->cd();
00389 treeSig->Write();
00390 treeBkg->Write();
00391
00392 file->Close();
00393 file.reset();
00394 file = std::auto_ptr<TFile>(TFile::Open(
00395 trainer->trainFileName(this, "root",
00396 "input").c_str()));
00397 if (!file.get())
00398 throw cms::Exception("ProcTMVA")
00399 << "Could not open ROOT file for "
00400 "reading." << std::endl;
00401 treeSig = dynamic_cast<TTree*>(
00402 file->Get((getTreeName() + "_sig").c_str()));
00403 treeBkg = dynamic_cast<TTree*>(
00404 file->Get((getTreeName() + "_bkg").c_str()));
00405
00406 runTMVATrainer();
00407
00408 file->Close();
00409 treeSig = 0;
00410 treeBkg = 0;
00411 file.reset();
00412 }
00413 vars.clear();
00414
00415 iteration = ITER_DONE;
00416 trained = true;
00417 break;
00418 default:
00419 ;
00420 }
00421 }
00422
00423 void ProcTMVA::cleanup()
00424 {
00425 if (!needCleanup)
00426 return;
00427
00428 std::remove(trainer->trainFileName(this, "root", "input").c_str());
00429 std::remove(trainer->trainFileName(this, "root", "output").c_str());
00430 for(std::vector<Method>::const_iterator iter = methods.begin();
00431 iter != methods.end(); ++iter) {
00432 std::remove(getWeightsFile(*iter, "txt").c_str());
00433 std::remove(getWeightsFile(*iter, "root").c_str());
00434 }
00435 rmdir("weights");
00436 }
00437
00438 }