00001 #include <unistd.h>
00002 #include <functional>
00003 #include <algorithm>
00004 #include <iostream>
00005 #include <sstream>
00006 #include <fstream>
00007 #include <cstddef>
00008 #include <cstring>
00009 #include <cstdio>
00010 #include <vector>
00011 #include <memory>
00012
00013 #include <xercesc/dom/DOM.hpp>
00014
00015 #include <TDirectory.h>
00016 #include <TTree.h>
00017 #include <TFile.h>
00018 #include <TCut.h>
00019
00020 #include "FWCore/Utilities/interface/Exception.h"
00021
00022 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00023 #include "PhysicsTools/MVAComputer/interface/memstream.h"
00024 #include "PhysicsTools/MVAComputer/interface/zstream.h"
00025
00026 #include "PhysicsTools/MVATrainer/interface/XMLDocument.h"
00027 #include "PhysicsTools/MVATrainer/interface/XMLSimpleStr.h"
00028 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00029 #include "PhysicsTools/MVATrainer/interface/SourceVariable.h"
00030 #include "PhysicsTools/MVATrainer/interface/TrainProcessor.h"
00031
00032 XERCES_CPP_NAMESPACE_USE
00033
00034 using namespace PhysicsTools;
00035
00036 namespace {
00037
00038 class ROOTContextSentinel {
00039 public:
00040 ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
00041 ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
00042
00043 private:
00044 TDirectory *dir;
00045 TFile *file;
00046 };
00047
00048 class TreeSaver : public TrainProcessor {
00049 public:
00050 typedef TrainProcessor::Registry<TreeSaver>::Type Registry;
00051
00052 TreeSaver(const char *name, const AtomicId *id,
00053 MVATrainer *trainer);
00054 virtual ~TreeSaver();
00055
00056 virtual void configure(DOMElement *elem);
00057 virtual void passFlags(const std::vector<Variable::Flags> &flags);
00058
00059 virtual void trainBegin();
00060 virtual void trainData(const std::vector<double> *values,
00061 bool target, double weight);
00062 virtual void trainEnd();
00063
00064 private:
00065 void init();
00066
00067 std::string getTreeName() const
00068 { return trainer->getName() + '_' + (const char*)getName(); }
00069
00070 enum Iteration {
00071 ITER_EXPORT,
00072 ITER_DONE
00073 } iteration;
00074
00075 struct Var {
00076 std::string name;
00077 Variable::Flags flags;
00078 double value;
00079 std::vector<double> values;
00080 std::vector<double> *ptr;
00081
00082 bool hasName(std::string other) const
00083 { return name == other; }
00084 };
00085
00086 std::auto_ptr<TFile> file;
00087 TTree *tree;
00088 Double_t weight;
00089 Bool_t target;
00090 std::vector<Var> vars;
00091 bool flagsPassed, begun;
00092 };
00093
00094 static TreeSaver::Registry registry("TreeSaver");
00095
00096 TreeSaver::TreeSaver(const char *name, const AtomicId *id,
00097 MVATrainer *trainer) :
00098 TrainProcessor(name, id, trainer),
00099 iteration(ITER_EXPORT), tree(0), flagsPassed(false), begun(false)
00100 {
00101 }
00102
00103 TreeSaver::~TreeSaver()
00104 {
00105 }
00106
00107 void TreeSaver::configure(DOMElement *elem)
00108 {
00109 std::vector<SourceVariable*> inputs = getInputs().get();
00110
00111 for(std::vector<SourceVariable*>::const_iterator iter = inputs.begin();
00112 iter != inputs.end(); iter++) {
00113 std::string name = (const char*)(*iter)->getName();
00114
00115 if (std::find_if(vars.begin(), vars.end(),
00116 std::bind2nd(std::mem_fun_ref(&Var::hasName),
00117 name)) != vars.end()) {
00118 for(unsigned i = 1;; i++) {
00119 std::ostringstream ss;
00120 ss << name << "_" << i;
00121 if (std::find_if(vars.begin(), vars.end(),
00122 std::bind2nd(
00123 std::mem_fun_ref(
00124 &Var::hasName),
00125 name)) == vars.end())
00126 break;
00127 }
00128 }
00129
00130 Var var;
00131 var.name = name;
00132 var.flags = Variable::FLAG_NONE;
00133 var.ptr = 0;
00134 vars.push_back(var);
00135 }
00136 }
00137
00138 void TreeSaver::init()
00139 {
00140 tree->Branch("__WEIGHT__", &weight, "__WEIGHT__/D");
00141 tree->Branch("__TARGET__", &target, "__TARGET__/O");
00142
00143 vars.resize(vars.size());
00144
00145 std::vector<Var>::iterator pos = vars.begin();
00146 for(std::vector<Var>::iterator iter = vars.begin();
00147 iter != vars.end(); iter++, pos++) {
00148 if (iter->flags & Variable::FLAG_MULTIPLE) {
00149 iter->ptr = &iter->values;
00150 tree->Branch(iter->name.c_str(),
00151 "std::vector<double>",
00152 &pos->ptr);
00153 } else
00154 tree->Branch(iter->name.c_str(), &pos->value,
00155 (iter->name + "/D").c_str());
00156 }
00157 }
00158
00159 void TreeSaver::passFlags(const std::vector<Variable::Flags> &flags)
00160 {
00161 assert(flags.size() == vars.size());
00162 unsigned int idx = 0;
00163 for(std::vector<Variable::Flags>::const_iterator iter = flags.begin();
00164 iter != flags.end(); ++iter, idx++)
00165 vars[idx].flags = *iter;
00166
00167 if (begun && !flagsPassed)
00168 init();
00169 flagsPassed = true;
00170 }
00171
00172 void TreeSaver::trainBegin()
00173 {
00174 if (iteration == ITER_EXPORT) {
00175 ROOTContextSentinel ctx;
00176
00177 file = std::auto_ptr<TFile>(TFile::Open(
00178 trainer->trainFileName(this, "root").c_str(),
00179 "RECREATE"));
00180 if (!file.get())
00181 throw cms::Exception("TreeSaver")
00182 << "Could not open ROOT file for writing."
00183 << std::endl;
00184
00185 file->cd();
00186 tree = new TTree(getTreeName().c_str(),
00187 "MVATrainer signal and background");
00188
00189 if (!begun && flagsPassed)
00190 init();
00191 begun = true;
00192 }
00193 }
00194
00195 void TreeSaver::trainData(const std::vector<double> *values,
00196 bool target, double weight)
00197 {
00198 if (iteration != ITER_EXPORT)
00199 return;
00200
00201 this->weight = weight;
00202 this->target = target;
00203 for(unsigned int i = 0; i < vars.size(); i++, values++) {
00204 Var &var = vars[i];
00205 if (var.flags & Variable::FLAG_MULTIPLE)
00206 var.values = *values;
00207 else if (values->empty())
00208 var.value = -999.0;
00209 else
00210 var.value = values->front();
00211 }
00212
00213 tree->Fill();
00214 }
00215
00216 void TreeSaver::trainEnd()
00217 {
00218 switch(iteration) {
00219 case ITER_EXPORT:
00220 {
00221 ROOTContextSentinel ctx;
00222 file->cd();
00223 tree->Write();
00224 file->Close();
00225 file.reset();
00226 }
00227 vars.clear();
00228
00229 iteration = ITER_DONE;
00230 trained = true;
00231 break;
00232 default:
00233 ;
00234 }
00235 }
00236
00237 }