CMS 3D CMS Logo

ProcTMVA.cc

Go to the documentation of this file.
00001 #include <unistd.h>
00002 #include <algorithm>
00003 #include <iostream>
00004 #include <sstream>
00005 #include <fstream>
00006 #include <cstddef>
00007 #include <cstring>
00008 #include <cstdio>
00009 #include <vector>
00010 #include <memory>
00011 
00012 #include <xercesc/dom/DOM.hpp>
00013 
00014 // ROOT version magic to support TMVA interface changes in newer ROOT   
00015 #include <RVersion.h>
00016 
00017 #include <TDirectory.h>
00018 #include <TTree.h>
00019 #include <TFile.h>
00020 #include <TCut.h>
00021 
00022 #include <TMVA/Types.h>
00023 #include <TMVA/Factory.h>
00024 
00025 #include "FWCore/Utilities/interface/Exception.h"
00026 
00027 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00028 #include "PhysicsTools/MVAComputer/interface/memstream.h"
00029 #include "PhysicsTools/MVAComputer/interface/zstream.h"
00030 
00031 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00032 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00033 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00034 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00035 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00036 
00037 XERCES_CPP_NAMESPACE_USE
00038 
00039 using namespace PhysicsTools;
00040 
00041 namespace { // anonymous
00042 
00043 class ROOTContextSentinel {
00044     public:
00045         ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
00046         ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
00047 
00048     private:
00049         TDirectory      *dir;
00050         TFile           *file;
00051 };
00052 
00053 class ProcTMVA : public TrainProcessor {
00054     public:
00055         typedef TrainProcessor::Registry<ProcTMVA>::Type Registry;
00056 
00057         ProcTMVA(const char *name, const AtomicId *id,
00058                  MVATrainer *trainer);
00059         virtual ~ProcTMVA();
00060 
00061         virtual void configure(DOMElement *elem);
00062         virtual Calibration::VarProcessor *getCalibration() const;
00063 
00064         virtual void trainBegin();
00065         virtual void trainData(const std::vector<double> *values,
00066                                bool target, double weight);
00067         virtual void trainEnd();
00068 
00069         virtual bool load();
00070         virtual void cleanup();
00071 
00072     private:
00073         void runTMVATrainer();
00074 
00075         struct Method {
00076                 TMVA::Types::EMVA       type;
00077                 std::string             name;
00078                 std::string             description;
00079         };
00080 
00081         std::string getTreeName() const
00082         { return trainer->getName() + '_' + (const char*)getName(); }
00083 
00084         std::string getWeightsFile(const Method &meth, const char *ext) const
00085         {
00086                 return "weights/" + getTreeName() + '_' +
00087                        meth.name + ".weights." + ext;
00088         }
00089 
00090         enum Iteration {
00091                 ITER_EXPORT,
00092                 ITER_DONE
00093         } iteration;
00094 
00095         std::vector<Method>             methods;
00096         std::vector<std::string>        names;
00097         std::auto_ptr<TFile>            file;
00098         TTree                           *treeSig, *treeBkg;
00099         Double_t                        weight;
00100         std::vector<Double_t>           vars;
00101         bool                            needCleanup;
00102         unsigned long                   nSignal;
00103         unsigned long                   nBackground;
00104         bool                            doUserTreeSetup;
00105         std::string                     setupCuts;      // cut applied by TMVA to signal and background trees
00106         std::string                     setupOptions;   // training/test tree TMVA setup options
00107 };
00108 
00109 static ProcTMVA::Registry registry("ProcTMVA");
00110 
00111 ProcTMVA::ProcTMVA(const char *name, const AtomicId *id,
00112                    MVATrainer *trainer) :
00113         TrainProcessor(name, id, trainer),
00114         iteration(ITER_EXPORT), treeSig(0), treeBkg(0), needCleanup(false),
00115         doUserTreeSetup(false), setupOptions("SplitMode = Block:!V")
00116 {
00117 }
00118 
00119 ProcTMVA::~ProcTMVA()
00120 {
00121 }
00122 
00123 void ProcTMVA::configure(DOMElement *elem)
00124 {
00125         std::vector<SourceVariable*> inputs = getInputs().get();
00126 
00127         for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
00128             iter != inputs.end(); iter++) {
00129                 std::string name = (const char*)(*iter)->getName();
00130 
00131                 if (std::find(names.begin(), names.end(), name)
00132                     != names.end()) {
00133                         for(unsigned i = 1;; i++) {
00134                                 std::ostringstream ss;
00135                                 ss << name << "_" << i;
00136                                 if (std::find(names.begin(), names.end(),
00137                                               ss.str()) == names.end()) {
00138                                         name == ss.str();
00139                                         break;
00140                                 }
00141                         }
00142                 }
00143 
00144                 names.push_back(name);
00145         }
00146 
00147         for(DOMNode *node = elem->getFirstChild();
00148             node; node = node->getNextSibling()) {
00149                 if (node->getNodeType() != DOMNode::ELEMENT_NODE)
00150                         continue;
00151 
00152                 bool isMethod = !std::strcmp(XMLSimpleStr(node->getNodeName()), "method");
00153                 bool isSetup  = !std::strcmp(XMLSimpleStr(node->getNodeName()), "setup");
00154 
00155                 if (!isMethod && !isSetup)
00156                         throw cms::Exception("ProcTMVA")
00157                                 << "Expected method or setup tag in config section."
00158                                 << std::endl;
00159 
00160                 elem = static_cast<DOMElement*>(node);
00161 
00162                 if (isMethod) {
00163                         Method method;
00164                         method.type = TMVA::Types::Instance().GetMethodType(
00165                                 XMLDocument::readAttribute<std::string>(
00166                                                         elem, "type").c_str());
00167 
00168                         method.name =
00169                                 XMLDocument::readAttribute<std::string>(
00170                                                         elem, "name");
00171 
00172                         method.description =
00173                                 (const char*)XMLSimpleStr(node->getTextContent());
00174 
00175                         methods.push_back(method);
00176                 } else if (isSetup) {
00177                         if (doUserTreeSetup)
00178                                 throw cms::Exception("ProcTMVA")
00179                                         << "Multiple appeareances of setup "
00180                                            "tag in config section."
00181                                         << std::endl;
00182 
00183                         doUserTreeSetup = true;
00184 
00185                         setupCuts = 
00186                                 XMLDocument::readAttribute<std::string>(
00187                                                         elem, "cuts");
00188                         setupOptions =
00189                                 XMLDocument::readAttribute<std::string>(
00190                                                         elem, "options");
00191                 }
00192         }
00193 
00194         if (!methods.size())
00195                 throw cms::Exception("ProcTMVA")
00196                         << "Expected TMVA method in config section."
00197                         << std::endl;
00198 }
00199 
00200 bool ProcTMVA::load()
00201 {
00202         bool ok = true;
00203         for(std::vector<Method>::const_iterator iter = methods.begin();
00204             iter != methods.end(); ++iter) {
00205                 std::ifstream in(getWeightsFile(*iter, "txt").c_str());
00206                 if (!in.good()) {
00207                         ok = false;
00208                         break;
00209                 }
00210         }
00211 
00212         if (!ok)
00213                 return false;
00214 
00215         iteration = ITER_DONE;
00216         trained = true;
00217         return true;
00218 }
00219 
/**
 * Number of bytes between the current read position and the end of @p in.
 *
 * The read position is moved to the end to measure the stream and then put
 * back to where it was on entry.
 *
 * @param in  input file stream to measure
 * @return    distance from the current read position to the end of stream
 */
static std::size_t getStreamSize(std::ifstream &in)
{
        const std::ifstream::pos_type current = in.tellg();

        in.seekg(0, std::ios::end);
        const std::ifstream::pos_type last = in.tellg();

        // return to the position the caller left the stream at
        in.seekg(current, std::ios::beg);

        const std::streamoff remaining = last - current;
        return static_cast<std::size_t>(remaining);
}
00229 
00230 Calibration::VarProcessor *ProcTMVA::getCalibration() const
00231 {
00232         Calibration::ProcTMVA *calib = new Calibration::ProcTMVA;
00233 
00234         std::ifstream in(getWeightsFile(methods[0], "txt").c_str(),
00235                          std::ios::binary | std::ios::in);
00236         if (!in.good())
00237                 throw cms::Exception("ProcTMVA")
00238                         << "Weights file " << getWeightsFile(methods[0], "txt")
00239                         << " cannot be opened for reading." << std::endl;
00240 
00241         std::size_t size = getStreamSize(in);
00242         size = size + (size / 32) + 128;
00243 
00244         char *buffer = 0;
00245         try {
00246                 buffer = new char[size];
00247                 ext::omemstream os(buffer, size);
00248                 /* call dtor of ozs at end */ {
00249                         ext::ozstream ozs(&os);
00250                         ozs << in.rdbuf();
00251                         ozs.flush();
00252                 }
00253                 size = os.end() - os.begin();
00254                 calib->store.resize(size);
00255                 std::memcpy(&calib->store.front(), os.begin(), size);
00256         } catch(...) {
00257                 delete[] buffer;
00258                 throw;
00259         }
00260         delete[] buffer;
00261         in.close();
00262 
00263         calib->method = methods[0].name;
00264         calib->variables = names;
00265 
00266         return calib;
00267 }
00268 
00269 void ProcTMVA::trainBegin()
00270 {
00271         if (iteration == ITER_EXPORT) {
00272                 ROOTContextSentinel ctx;
00273 
00274                 file = std::auto_ptr<TFile>(TFile::Open(
00275                         trainer->trainFileName(this, "root",
00276                                                "input").c_str(),
00277                         "RECREATE"));
00278                 if (!file.get())
00279                         throw cms::Exception("ProcTMVA")
00280                                 << "Could not open ROOT file for writing."
00281                                 << std::endl;
00282 
00283                 file->cd();
00284                 treeSig = new TTree((getTreeName() + "_sig").c_str(),
00285                                     "MVATrainer signal");
00286                 treeBkg = new TTree((getTreeName() + "_bkg").c_str(),
00287                                     "MVATrainer background");
00288 
00289                 treeSig->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
00290                 treeBkg->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
00291 
00292                 vars.resize(names.size());
00293 
00294                 std::vector<Double_t>::iterator pos = vars.begin();
00295                 for(std::vector<std::string>::const_iterator iter =
00296                         names.begin(); iter != names.end(); iter++, pos++) {
00297                         treeSig->Branch(iter->c_str(), &*pos,
00298                                         (*iter + "/D").c_str());
00299                         treeBkg->Branch(iter->c_str(), &*pos,
00300                                         (*iter + "/D").c_str());
00301                 }
00302 
00303                 nSignal = nBackground = 0;
00304         }
00305 }
00306 
00307 void ProcTMVA::trainData(const std::vector<double> *values,
00308                          bool target, double weight)
00309 {
00310         if (iteration != ITER_EXPORT)
00311                 return;
00312 
00313         this->weight = weight;
00314         for(unsigned int i = 0; i < vars.size(); i++, values++)
00315                 vars[i] = values->front();
00316 
00317         if (target) {
00318                 treeSig->Fill();
00319                 nSignal++;
00320         } else {
00321                 treeBkg->Fill();
00322                 nBackground++;
00323         }
00324 }
00325 
00326 void ProcTMVA::runTMVATrainer()
00327 {
00328         needCleanup = true;
00329 
00330         if (nSignal < 1 || nBackground < 1)
00331                 throw cms::Exception("ProcTMVA")
00332                         << "Not going to run TMVA: "
00333                            "No signal (" << nSignal << ") or background ("
00334                         << nBackground << ") events!" << std::endl;
00335 
00336         std::auto_ptr<TFile> file(TFile::Open(
00337                 trainer->trainFileName(this, "root", "output").c_str(),
00338                 "RECREATE"));
00339         if (!file.get())
00340                 throw cms::Exception("ProcTMVA")
00341                         << "Could not open TMVA ROOT file for writing."
00342                         << std::endl;
00343 
00344         std::auto_ptr<TMVA::Factory> factory(
00345                 new TMVA::Factory(getTreeName().c_str(), file.get(), ""));
00346 
00347         if (!factory->SetInputTrees(treeSig, treeBkg))
00348                 throw cms::Exception("ProcTMVA")
00349                         << "TMVA rejected input trees." << std::endl;
00350 
00351         for(std::vector<std::string>::const_iterator iter = names.begin();
00352             iter != names.end(); iter++)
00353                 factory->AddVariable(iter->c_str(), 'D');
00354 
00355         factory->SetWeightExpression("__WEIGHT__");
00356 
00357         if (doUserTreeSetup)
00358                 factory->PrepareTrainingAndTestTree(
00359                                         setupCuts.c_str(), setupOptions);
00360         else
00361                 factory->PrepareTrainingAndTestTree(
00362                                 "", nSignal, nBackground, 1, 1,
00363                                 "SplitMode=Block:!V");
00364 
00365         for(std::vector<Method>::const_iterator iter = methods.begin();
00366             iter != methods.end(); ++iter)
00367                 factory->BookMethod(iter->type, iter->name, iter->description);
00368 
00369         factory->TrainAllMethods();
00370         factory->TestAllMethods();
00371         factory->EvaluateAllMethods();
00372 
00373         factory.release(); // ROOT seems to take care of destruction?!
00374 
00375         file->Close();
00376 }
00377 
00378 void ProcTMVA::trainEnd()
00379 {
00380         switch(iteration) {
00381             case ITER_EXPORT:
00382                 // work around TMVA issue: fill 1 dummy sig and bkg test event
00383                 treeSig->Fill();
00384                 treeBkg->Fill();
00385 
00386                 /* ROOT context-safe */ {
00387                         ROOTContextSentinel ctx;
00388                         file->cd();
00389                         treeSig->Write();
00390                         treeBkg->Write();
00391 
00392                         file->Close();
00393                         file.reset();
00394                         file = std::auto_ptr<TFile>(TFile::Open(
00395                                 trainer->trainFileName(this, "root",
00396                                                        "input").c_str()));
00397                         if (!file.get())
00398                                 throw cms::Exception("ProcTMVA")
00399                                         << "Could not open ROOT file for "
00400                                            "reading." << std::endl;
00401                         treeSig = dynamic_cast<TTree*>(
00402                                 file->Get((getTreeName() + "_sig").c_str()));
00403                         treeBkg = dynamic_cast<TTree*>(
00404                                 file->Get((getTreeName() + "_bkg").c_str()));
00405 
00406                         runTMVATrainer();
00407 
00408                         file->Close();
00409                         treeSig = 0;
00410                         treeBkg = 0;
00411                         file.reset();
00412                 }
00413                 vars.clear();
00414 
00415                 iteration = ITER_DONE;
00416                 trained = true;
00417                 break;
00418             default:
00419                 /* shut up */;
00420         }
00421 }
00422 
00423 void ProcTMVA::cleanup()
00424 {
00425         if (!needCleanup)
00426                 return;
00427 
00428         std::remove(trainer->trainFileName(this, "root", "input").c_str());
00429         std::remove(trainer->trainFileName(this, "root", "output").c_str());
00430         for(std::vector<Method>::const_iterator iter = methods.begin();
00431             iter != methods.end(); ++iter) {
00432                 std::remove(getWeightsFile(*iter, "txt").c_str());
00433                 std::remove(getWeightsFile(*iter, "root").c_str());
00434         }
00435         rmdir("weights");
00436 }
00437 
00438 } // anonymous namespace

Generated on Tue Jun 9 17:41:31 2009 for CMSSW by  doxygen 1.5.4