CMS 3D CMS Logo

/afs/cern.ch/work/a/aaltunda/public/www/CMSSW_6_2_5/src/PhysicsTools/MVAComputer/src/TreeReader.cc

Go to the documentation of this file.
00001 #include <stdint.h>
00002 #include <utility>
00003 #include <cstring>
00004 #include <string>
00005 #include <vector>
00006 #include <map>
00007 
00008 #include <TString.h>
00009 #include <TTree.h>
00010 #include <TBranch.h>
00011 #include <TLeaf.h>
00012 #include <TList.h>
00013 #include <TKey.h>
00014 
00015 #include "FWCore/Utilities/interface/Exception.h"
00016 
00017 #include "PhysicsTools/MVAComputer/interface/AtomicId.h"
00018 #include "PhysicsTools/MVAComputer/interface/MVAComputer.h"
00019 #include "PhysicsTools/MVAComputer/interface/TreeReader.h"
00020 
00021 namespace PhysicsTools {
00022 
00023 const double TreeReader::kOptVal = -999.0;
00024 
00025 TreeReader::TreeReader() :
00026         tree(0), upToDate(false)
00027 {
00028 }
00029 
00030 TreeReader::TreeReader(const TreeReader &orig)
00031 {
00032         this->operator = (orig);
00033 }
00034 
00035 TreeReader::TreeReader(TTree *tree, bool skipTarget, bool skipWeight) :
00036         tree(tree), upToDate(false)
00037 {
00038         automaticAdd(skipTarget, skipWeight);
00039 }
00040 
00041 TreeReader::~TreeReader()
00042 {
00043 }
00044 
00045 TreeReader &TreeReader::operator = (const TreeReader &orig)
00046 {
00047         reset();
00048 
00049         tree = orig.tree;
00050 
00051         multiDouble.resize(orig.multiDouble.size());
00052         multiFloat.resize(orig.multiFloat.size());
00053         multiInt.resize(orig.multiInt.size());
00054         multiBool.resize(orig.multiBool.size());
00055 
00056         singleDouble.resize(orig.singleDouble.size());
00057         singleFloat.resize(orig.singleFloat.size());
00058         singleInt.resize(orig.singleInt.size());
00059         singleBool.resize(orig.singleBool.size());
00060 
00061         valueMap = orig.valueMap;
00062 
00063         return *this;
00064 }
00065 
00066 void TreeReader::setTree(TTree *tree)
00067 {
00068         this->tree = tree;
00069         upToDate = false;
00070 }
00071 
00072 void TreeReader::addBranch(const std::string &expression,
00073                            AtomicId name, bool opt)
00074 {
00075         if (!tree)
00076                 throw cms::Exception("NoTreeAvailable")
00077                         << "No TTree set in TreeReader::addBranch."
00078                         << std::endl;
00079 
00080         TBranch *branch = tree->GetBranch(expression.c_str());
00081         if (!branch)
00082                 throw cms::Exception("BranchMissing")
00083                         << "Tree branch \"" << expression << "\" missing."
00084                         << std::endl;
00085 
00086         addBranch(branch, name, opt);
00087 }
00088 
00089 void TreeReader::addBranch(TBranch *branch, AtomicId name, bool opt)
00090 {
00091         TString branchName = branch->GetName();
00092         if (!name)
00093                 name = (const char*)branchName;
00094 
00095         TLeaf *leaf = dynamic_cast<TLeaf*>(branch->GetLeaf(branchName));
00096         if (!leaf)
00097                 throw cms::Exception("InvalidBranch")
00098                         << "Tree branch \"" << branchName << "\" has no leaf."
00099                         << std::endl;
00100 
00101         TString typeName = leaf->GetTypeName();
00102         char typeId = 0;
00103         bool multi = false;
00104         if (typeName == "Double_t" || typeName == "double")
00105                 typeId = 'D';
00106         else if (typeName == "Float_t" || typeName == "float")
00107                 typeId = 'F';
00108         else if (typeName == "Int_t" || typeName == "int")
00109                 typeId = 'I';
00110         else if (typeName == "Bool_t" || typeName == "bool")
00111                 typeId = 'B';
00112         else {
00113                 multi = true;
00114                 if (typeName == "vector<double>" ||
00115                     typeName == "Vector<Double_t>")
00116                         typeId = 'D';
00117                 else if (typeName == "vector<float>" ||
00118                          typeName == "Vector<Float_t>")
00119                         typeId = 'F';
00120                 else if (typeName == "vector<int>" ||
00121                          typeName == "Vector<Int_t>")
00122                         typeId = 'I';
00123                 else if (typeName == "vector<bool>" ||
00124                          typeName == "Vector<Bool_t>")
00125                         typeId = 'B';
00126         }
00127 
00128         if (!typeId)
00129                 throw cms::Exception("InvalidBranch")
00130                         << "Tree branch \"" << branchName << "\" is of "
00131                            "unsupported type \"" << typeName << "\"."
00132                         << std::endl;
00133 
00134         if (multi)
00135                 addTypeMulti(name, 0, typeId);
00136         else
00137                 addTypeSingle(name, 0, typeId, opt);
00138 
00139         valueMap[name].setBranchName(branch->GetName());
00140 }
00141 
00142 void TreeReader::setOptional(AtomicId name, bool opt, double optVal)
00143 {
00144         std::map<AtomicId, Value>::iterator pos = valueMap.find(name);
00145         if (pos == valueMap.end())
00146                 throw cms::Exception("UnknownVariable")
00147                         << "Variable \"" <<name << "\" is not known to the "
00148                            "TreeReader." << std::endl;
00149 
00150         pos->second.setOpt(opt, optVal);
00151 }
00152 
00153 void TreeReader::addTypeSingle(AtomicId name, const void *value, char type, bool opt)
00154 {
00155         std::map<AtomicId, Value>::const_iterator pos = valueMap.find(name);
00156         if (pos != valueMap.end())
00157                 throw cms::Exception("DuplicateVariable")
00158                         << "Duplicate Variable \"" << name << "\"."
00159                         << std::endl;
00160 
00161         if (type != 'D' && type != 'F' && type != 'I' && type != 'B')
00162                 throw cms::Exception("InvalidType")
00163                         << "Unsupported type '" << type << "' in call to"
00164                            "TreeReader::addTypeSingle." << std::endl;
00165 
00166         int index = -1;
00167         if (!value) {
00168                 switch(type) {
00169                     case 'D':
00170                         index = (int)singleDouble.size();
00171                         singleDouble.push_back(Double_t());
00172                         break;
00173                     case 'F':
00174                         index = (int)singleFloat.size();
00175                         singleFloat.push_back(Float_t());
00176                         break;
00177                     case 'I':
00178                         index = (int)singleInt.size();
00179                         singleInt.push_back(Int_t());
00180                         break;
00181                     case 'B':
00182                         index = (int)singleBool.size();
00183                         singleBool.push_back(Bool());
00184                         break;
00185                 }
00186         }
00187 
00188         valueMap[name] = Value(index, false, opt, type);
00189         if (value)
00190                 valueMap[name].setPtr(value);
00191 
00192         upToDate = false;
00193 }
00194 
/// Creates an empty buffer slot for a multi-valued (vector) branch: a null
/// pointer, later pointed at the vector by Value::update(), paired with an
/// empty vector.
template<typename T>
static std::pair<void*, std::vector<T> > makeMulti()
{
        std::pair<void*, std::vector<T> > slot;
        slot.first = 0;
        return slot;
}
00198 
00199 void TreeReader::addTypeMulti(AtomicId name, const void *value, char type)
00200 {
00201         std::map<AtomicId, Value>::const_iterator pos = valueMap.find(name);
00202         if (pos != valueMap.end())
00203                 throw cms::Exception("DuplicateVariable")
00204                         << "Duplicate Variable \"" << name << "\"."
00205                         << std::endl;
00206 
00207         if (type != 'D' && type != 'F' && type != 'I' && type != 'B')
00208                 throw cms::Exception("InvalidType")
00209                         << "Unsupported type '" << type << "' in call to"
00210                            "TreeReader::addTypeMulti." << std::endl;
00211 
00212         int index = -1;
00213         if (!value) {
00214                 switch(type) {
00215                     case 'D':
00216                         index = (int)multiDouble.size();
00217                         multiDouble.push_back(makeMulti<Double_t>());
00218                         break;
00219                     case 'F':
00220                         index = (int)multiFloat.size();
00221                         multiFloat.push_back(makeMulti<Float_t>());
00222                         break;
00223                     case 'I':
00224                         index = (int)multiInt.size();
00225                         multiInt.push_back(makeMulti<Int_t>());
00226                         break;
00227                     case 'B':
00228                         index = (int)multiBool.size();
00229                         multiBool.push_back(makeMulti<Bool_t>());
00230                         break;
00231                 }
00232         }
00233 
00234         valueMap[name] = Value(index, true, false, type);
00235         if (value)
00236                 valueMap[name].setPtr(value);
00237 
00238         upToDate = false;
00239 }
00240 
00241 void TreeReader::automaticAdd(bool skipTarget, bool skipWeight)
00242 {
00243         if (!tree)
00244                 throw cms::Exception("NoTreeAvailable")
00245                         << "No TTree set in TreeReader::automaticAdd."
00246                         << std::endl;
00247 
00248         TIter iter(tree->GetListOfBranches());
00249         TObject *obj;
00250         while((obj = iter())) {
00251                 TBranch *branch = dynamic_cast<TBranch*>(obj);
00252                 if (!branch)
00253                         continue;
00254 
00255                 if (skipTarget &&
00256                     !std::strcmp(branch->GetName(), "__TARGET__"))
00257                         continue;
00258 
00259                 if (skipWeight &&
00260                     !std::strcmp(branch->GetName(), "__WEIGHT__"))
00261                         continue;
00262 
00263                 addBranch(branch);
00264         }
00265 }
00266 
00267 void TreeReader::reset()
00268 {
00269         multiDouble.clear();
00270         multiFloat.clear();
00271         multiInt.clear();
00272         multiBool.clear();
00273 
00274         singleDouble.clear();
00275         singleFloat.clear();
00276         singleInt.clear();
00277         singleBool.clear();
00278 
00279         valueMap.clear();
00280 
00281         upToDate = false;
00282 }
00283 
00284 void TreeReader::update()
00285 {
00286         if (!tree)
00287                 throw cms::Exception("NoTreeAvailable")
00288                         << "No TTree set in TreeReader::automaticAdd."
00289                         << std::endl;
00290 
00291         for(std::map<AtomicId, Value>::iterator iter = valueMap.begin();
00292             iter != valueMap.end(); iter++)
00293                 iter->second.update(this);
00294 
00295         upToDate = true;
00296 }
00297 
00298 uint64_t TreeReader::loop(const MVAComputer *mva)
00299 {
00300         if (!tree)
00301                 throw cms::Exception("NoTreeAvailable")
00302                         << "No TTree set in TreeReader::automaticAdd."
00303                         << std::endl;
00304 
00305         if (!upToDate)
00306                 update();
00307 
00308         Long64_t entries = tree->GetEntries();
00309         for(Long64_t entry = 0; entry < entries; entry++)
00310         {
00311                 tree->GetEntry(entry);
00312                 fill(mva);
00313         }
00314 
00315         return entries;
00316 }
00317 
00318 double TreeReader::fill(const MVAComputer *mva)
00319 {
00320         for(std::map<AtomicId, Value>::const_iterator iter = valueMap.begin();
00321             iter != valueMap.end(); iter++)
00322                 iter->second.fill(iter->first, this);
00323 
00324         double result = mva->eval(values);
00325         values.clear();
00326 
00327         return result;
00328 }
00329 
00330 Variable::ValueList TreeReader::fill()
00331 {
00332         for(std::map<AtomicId, Value>::const_iterator iter = valueMap.begin();
00333             iter != valueMap.end(); iter++)
00334                 iter->second.fill(iter->first, this);
00335 
00336         Variable::ValueList result = values;
00337         values.clear();
00338 
00339         return result;
00340 }
00341 
00342 std::vector<AtomicId> TreeReader::variables() const
00343 {
00344         std::vector<AtomicId> result;
00345         for(std::map<AtomicId, Value>::const_iterator iter = valueMap.begin();
00346             iter != valueMap.end(); iter++)
00347                 result.push_back(iter->first);
00348 
00349         return result;
00350 }
00351 
00352 void TreeReader::Value::update(TreeReader *reader) const
00353 {
00354         if (ptr)
00355                 return;
00356 
00357         void *value = 0;
00358         if (multiple) {
00359                 switch(type) {
00360                     case 'D':
00361                         reader->multiDouble[index].first =
00362                                 &reader->multiDouble[index].second;
00363                         value = &reader->multiDouble[index].first;
00364                         break;
00365                     case 'F':
00366                         reader->multiFloat[index].first =
00367                                 &reader->multiFloat[index].second;
00368                         value = &reader->multiFloat[index].first;
00369                         break;
00370                     case 'I':
00371                         reader->multiInt[index].first =
00372                                 &reader->multiInt[index].second;
00373                         value = &reader->multiInt[index].first;
00374                         break;
00375                     case 'B':
00376                         reader->multiBool[index].first = value;
00377                                 &reader->multiBool[index].second;
00378                         value = &reader->multiBool[index].first;
00379                         break;
00380                 }
00381         } else {
00382                 switch(type) {
00383                     case 'D':
00384                         value = &reader->singleDouble[index];
00385                         break;
00386                     case 'F':
00387                         value = &reader->singleFloat[index];
00388                         break;
00389                     case 'I':
00390                         value = &reader->singleInt[index];
00391                         break;
00392                     case 'B':
00393                         value = &reader->singleBool[index];
00394                         break;
00395                 }
00396         }
00397 
00398         reader->tree->SetBranchAddress(name, value);
00399 }
00400 
00401 void TreeReader::Value::fill(AtomicId name, TreeReader *reader) const
00402 {
00403         if (multiple) {
00404                 switch(type) {
00405                     case 'D': {
00406                         const std::vector<Double_t> *values =
00407                                 static_cast<const std::vector<Double_t>*>(ptr);
00408                         if (!values)
00409                                 values = &reader->multiDouble[index].second;
00410                         for(std::vector<Double_t>::const_iterator iter =
00411                                 values->begin(); iter != values->end(); iter++)
00412                                 reader->values.add(name, *iter);
00413                         break;
00414                     }
00415                     case 'F': {
00416                         const std::vector<Float_t> *values =
00417                                 static_cast<const std::vector<Float_t>*>(ptr);
00418                         if (!values)
00419                                 values = &reader->multiFloat[index].second;
00420                         for(std::vector<Float_t>::const_iterator iter =
00421                                 values->begin(); iter != values->end(); iter++)
00422                                 reader->values.add(name, *iter);
00423                         break;
00424                     }
00425                     case 'I': {
00426                         const std::vector<Int_t> *values =
00427                                 static_cast<const std::vector<Int_t>*>(ptr);
00428                         if (!values)
00429                                 values = &reader->multiInt[index].second;
00430                         for(std::vector<Int_t>::const_iterator iter =
00431                                 values->begin(); iter != values->end(); iter++)
00432                                 reader->values.add(name, *iter);
00433                         break;
00434                     }
00435                     case 'B': {
00436                         const std::vector<Bool_t> *values =
00437                                 static_cast<const std::vector<Bool_t>*>(ptr);
00438                         if (!values)
00439                                 values = &reader->multiBool[index].second;
00440                         for(std::vector<Bool_t>::const_iterator iter =
00441                                 values->begin(); iter != values->end(); iter++)
00442                                 reader->values.add(name, *iter);
00443                         break;
00444                     }
00445                 }
00446         } else {
00447                 double value = 0.0;
00448 
00449                 switch(type) {
00450                     case 'D':
00451                         value = ptr ? *(const Double_t*)ptr
00452                                     : reader->singleDouble[index];
00453                         break;
00454                     case 'F':
00455                         value = ptr ? *(const Float_t*)ptr
00456                                     : reader->singleFloat[index];
00457                         break;
00458                     case 'I':
00459                         value = ptr ? *(const Int_t*)ptr
00460                                     : reader->singleInt[index];
00461                         break;
00462                     case 'B':
00463                         value = ptr ? *(const Bool_t*)ptr
00464                                     : reader->singleBool[index];
00465                         break;
00466                 }
00467 
00468                 if (!optional || value != optVal)
00469                         reader->values.add(name, value);
00470         }
00471 }
00472 
00473 #define TREEREADER_ADD_IMPL(T, C) \
00474 template<> \
00475 void TreeReader::addSingle<T>(AtomicId name, const T *value, bool opt) \
00476 { addTypeSingle(name, value, C, opt); } \
00477 \
00478 template<> \
00479 void TreeReader::addMulti(AtomicId name, const std::vector<T> *value) \
00480 { addTypeMulti(name, value, C); }
00481 
00482 TREEREADER_ADD_IMPL(Double_t, 'D')
00483 TREEREADER_ADD_IMPL(Float_t, 'F')
00484 TREEREADER_ADD_IMPL(Int_t, 'I')
00485 TREEREADER_ADD_IMPL(Bool_t, 'B')
00486 
00487 #undef TREEREADER_ADD_IMPL
00488 
00489 } // namespace PhysicsTools