CMS 3D CMS Logo

TreeSaver.cc

Go to the documentation of this file.
00001 #include <unistd.h>
00002 #include <functional>
00003 #include <algorithm>
00004 #include <iostream>
00005 #include <sstream>
00006 #include <fstream>
00007 #include <cstddef>
00008 #include <cstring>
00009 #include <cstdio>
00010 #include <vector>
00011 #include <memory>
00012 
00013 #include <xercesc/dom/DOM.hpp>
00014 
00015 #include <TDirectory.h>
00016 #include <TTree.h>
00017 #include <TFile.h>
00018 #include <TCut.h>
00019 
00020 #include "FWCore/Utilities/interface/Exception.h"
00021 
00022 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00023 #include "PhysicsTools/MVAComputer/interface/memstream.h"
00024 #include "PhysicsTools/MVAComputer/interface/zstream.h"
00025 
00026 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00027 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00028 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00029 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00030 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00031 
00032 XERCES_CPP_NAMESPACE_USE
00033 
00034 using namespace PhysicsTools;
00035 
00036 namespace { // anonymous
00037 
00038 class ROOTContextSentinel {
00039     public:
00040         ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
00041         ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
00042 
00043     private:
00044         TDirectory      *dir;
00045         TFile           *file;
00046 };
00047 
00048 class TreeSaver : public TrainProcessor {
00049     public:
00050         typedef TrainProcessor::Registry<TreeSaver>::Type Registry;
00051 
00052         TreeSaver(const char *name, const AtomicId *id,
00053                  MVATrainer *trainer);
00054         virtual ~TreeSaver();
00055 
00056         virtual void configure(DOMElement *elem);
00057         virtual void passFlags(const std::vector<Variable::Flags> &flags);
00058 
00059         virtual void trainBegin();
00060         virtual void trainData(const std::vector<double> *values,
00061                                bool target, double weight);
00062         virtual void trainEnd();
00063 
00064     private:
00065         void init();
00066 
00067         std::string getTreeName() const
00068         { return trainer->getName() + '_' + (const char*)getName(); }
00069 
00070         enum Iteration {
00071                 ITER_EXPORT,
00072                 ITER_DONE
00073         } iteration;
00074 
00075         struct Var {
00076                 std::string             name;
00077                 Variable::Flags         flags;
00078                 double                  value;
00079                 std::vector<double>     values;
00080                 std::vector<double>     *ptr;
00081 
00082                 bool hasName(std::string other) const
00083                 { return name == other; }
00084         };
00085 
00086         std::auto_ptr<TFile>            file;
00087         TTree                           *tree;
00088         Double_t                        weight;
00089         Bool_t                          target;
00090         std::vector<Var>                vars;
00091         bool                            flagsPassed, begun;
00092 };
00093 
00094 static TreeSaver::Registry registry("TreeSaver");
00095 
00096 TreeSaver::TreeSaver(const char *name, const AtomicId *id,
00097                    MVATrainer *trainer) :
00098         TrainProcessor(name, id, trainer),
00099         iteration(ITER_EXPORT), tree(0), flagsPassed(false), begun(false)
00100 {
00101 }
00102 
00103 TreeSaver::~TreeSaver()
00104 {
00105 }
00106 
00107 void TreeSaver::configure(DOMElement *elem)
00108 {
00109         std::vector<SourceVariable*> inputs = getInputs().get();
00110 
00111         for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
00112             iter != inputs.end(); iter++) {
00113                 std::string name = (const char*)(*iter)->getName();
00114 
00115                 if (std::find_if(vars.begin(), vars.end(),
00116                                  std::bind2nd(std::mem_fun_ref(&Var::hasName),
00117                                               name)) != vars.end()) {
00118                         for(unsigned i = 1;; i++) {
00119                                 std::ostringstream ss;
00120                                 ss << name << "_" << i;
00121                                 if (std::find_if(vars.begin(), vars.end(),
00122                                                  std::bind2nd(
00123                                                         std::mem_fun_ref(
00124                                                                 &Var::hasName),
00125                                                         name)) == vars.end())
00126                                         break;
00127                         }
00128                 }
00129 
00130                 Var var;
00131                 var.name = name;
00132                 var.flags = Variable::FLAG_NONE;
00133                 var.ptr = 0;
00134                 vars.push_back(var);
00135         }
00136 }
00137 
00138 void TreeSaver::init()
00139 {
00140         tree->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
00141         tree->Branch("__TARGET__", &target, "__TARGET__/O");
00142 
00143         vars.resize(vars.size());
00144 
00145         std::vector<Var>::iterator pos = vars.begin();
00146         for(std::vector<Var>::iterator iter = vars.begin();
00147             iter != vars.end(); iter++, pos++) {
00148                 if (iter->flags & Variable::FLAG_MULTIPLE) {
00149                         iter->ptr = &iter->values;
00150                         tree->Branch(iter->name.c_str(),
00151                                      "std::vector<double>",
00152                                      &pos->ptr);
00153                 } else
00154                         tree->Branch(iter->name.c_str(), &pos->value,
00155                                     (iter->name + "/D").c_str());
00156         }
00157 }
00158 
00159 void TreeSaver::passFlags(const std::vector<Variable::Flags> &flags)
00160 {
00161         assert(flags.size() == vars.size());
00162         unsigned int idx = 0;
00163         for(std::vector<Variable::Flags>::const_iterator iter = flags.begin();
00164             iter != flags.end(); ++iter, idx++)
00165                 vars[idx].flags = *iter;
00166 
00167         if (begun && !flagsPassed)
00168                 init();
00169         flagsPassed = true;
00170 }
00171 
00172 void TreeSaver::trainBegin()
00173 {
00174         if (iteration == ITER_EXPORT) {
00175                 ROOTContextSentinel ctx;
00176 
00177                 file = std::auto_ptr<TFile>(TFile::Open(
00178                         trainer->trainFileName(this, "root").c_str(),
00179                         "RECREATE"));
00180                 if (!file.get())
00181                         throw cms::Exception("TreeSaver")
00182                                 << "Could not open ROOT file for writing."
00183                                 << std::endl;
00184 
00185                 file->cd();
00186                 tree = new TTree(getTreeName().c_str(),
00187                                  "MVATrainer signal and background");
00188 
00189                 if (!begun && flagsPassed)
00190                         init();
00191                 begun = true;
00192         }
00193 }
00194 
00195 void TreeSaver::trainData(const std::vector<double> *values,
00196                          bool target, double weight)
00197 {
00198         if (iteration != ITER_EXPORT)
00199                 return;
00200 
00201         this->weight = weight;
00202         this->target = target;
00203         for(unsigned int i = 0; i < vars.size(); i++, values++) {
00204                 Var &var = vars[i];
00205                 if (var.flags & Variable::FLAG_MULTIPLE)
00206                         var.values = *values;
00207                 else if (values->empty())
00208                         var.value = -999.0;
00209                 else
00210                         var.value = values->front();
00211         }
00212 
00213         tree->Fill();
00214 }
00215 
00216 void TreeSaver::trainEnd()
00217 {
00218         switch(iteration) {
00219             case ITER_EXPORT:
00220                 /* ROOT context-safe */ {
00221                         ROOTContextSentinel ctx;
00222                         file->cd();
00223                         tree->Write();
00224                         file->Close();
00225                         file.reset();
00226                 }
00227                 vars.clear();
00228 
00229                 iteration = ITER_DONE;
00230                 trained = true;
00231                 break;
00232             default:
00233                 /* shut up */;
00234         }
00235 }
00236 
00237 } // anonymous namespace

Generated on Tue Jun 9 17:41:32 2009 for CMSSW by  doxygen 1.5.4