CMS 3D CMS Logo

/data/doxygen/doxygen-1.7.3/gen/CMSSW_4_2_8/src/RecoBTau/JetTagMVALearning/plugins/JetTagMVATreeTrainer.cc

Go to the documentation of this file.
#include <algorithm>
#include <cmath>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
00009 
00010 #include <boost/shared_ptr.hpp>
00011 
00012 #include <TRandom.h>
00013 #include <TString.h>
00014 #include <TFile.h>
00015 #include <TTree.h>
00016 #include <TBranch.h>
00017 #include <TLeaf.h>
00018 #include <TList.h>
00019 #include <TKey.h>
00020 
00021 #include "FWCore/Framework/interface/MakerMacros.h"
00022 #include "FWCore/Utilities/interface/Exception.h"
00023 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00024 #include "FWCore/Framework/interface/Event.h"
00025 #include "FWCore/Framework/interface/EventSetup.h"
00026 #include "FWCore/Framework/interface/Run.h"
00027 #include "FWCore/Framework/interface/ESHandle.h"
00028 #include "FWCore/Framework/interface/EDAnalyzer.h"
00029 
00030 #include "SimDataFormats/JetMatching/interface/JetFlavourMatching.h"
00031 
00032 #include "DataFormats/Common/interface/Ref.h"
00033 #include "DataFormats/BTauReco/interface/JetTagInfo.h"
00034 #include "DataFormats/BTauReco/interface/TaggingVariable.h"
00035 
00036 #include "CondFormats/PhysicsToolsObjects/interface/MVAComputer.h"
00037 #include "CondFormats/DataRecord/interface/BTauGenericMVAJetTagComputerRcd.h"
00038 
00039 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00040 
00041 #include "RecoBTau/JetTagComputer/interface/JetTagComputerRecord.h"
00042 #include "RecoBTau/JetTagComputer/interface/GenericMVAComputer.h"
00043 #include "RecoBTau/JetTagComputer/interface/GenericMVAComputerCache.h"
00044 #include "RecoBTau/JetTagComputer/interface/TagInfoMVACategorySelector.h"
00045 
00046 #include "EventProgress.h"
00047 
00048 using namespace reco;
00049 using namespace PhysicsTools;
00050 
/// (pt, eta)-dependent factor used when computing jet training weights.
///
/// A Fit is either a constant, or a parameterisation read from a text file
/// holding 7 x 6 whitespace-separated coefficients in row-major order.
/// Row i is a 5th-order polynomial in y, a clamped, rescaled log(pt + 50).
/// The seven resulting factors multiply the even Chebyshev polynomials
/// T0, T2, ..., T12 of x = eta / 2.5, where x is clamped to [-1, 1].
class Fit {
    public:
        /// Constant fit evaluating to 0 (converts to false).
        Fit() : isFixed(true), fixedValue(0) {}

        /// Constant fit evaluating to @p value for every jet.
        explicit Fit(double value) : isFixed(true), fixedValue(value) {}

        /// Parameterised fit whose coefficients are read from @p fileName.
        /// @throws std::runtime_error if the file cannot be opened or does
        ///         not contain 7 x 6 readable numbers, so that a broken
        ///         file can never leave the coefficients uninitialised.
        explicit Fit(const std::string &fileName) :
                isFixed(false), fixedValue(0)
        {
                std::ifstream f(fileName.c_str());
                if (!f)
                        throw std::runtime_error(
                                "Fit: cannot open parameter file " + fileName);

                for(int i = 0; i < nEtaTerms; i++)
                        for(int j = 0; j < nPtTerms; j++)
                                if (!(f >> params[i][j]))
                                        throw std::runtime_error(
                                                "Fit: cannot read 7x6 parameters from " + fileName);
        }

        /// true for parameterised fits and for constants greater than 0.
        operator bool() const { return !isFixed || fixedValue > 0.0; }

        /// Evaluates the fit for one jet.
        /// @param pt    jet transverse momentum
        /// @param eta   jet pseudorapidity
        /// @param isRev return the reciprocal of the parameterised value
        ///              (constant fits ignore this flag)
        double operator () (double pt, double eta, bool isRev = false) const
        {
                if (isFixed)
                        return fixedValue;

                double x = std::min(std::max(-1.0, eta / 2.5), 1.0);
                // 4.0943445622221004 == ln(60): y is 0.5 at pt = 10 and is
                // clamped to [0, 36]
                double y = std::min(std::max(0.0, (std::log(pt + 50.0) - 4.0943445622221004) * 40.0 / 2.1 + 0.5), 36.0);

                // Horner evaluation of the pt polynomial of each eta term
                double facs[nEtaTerms];
                for(int i = 0; i < nEtaTerms; i++) {
                        const double *v = params[i];
                        facs[i] = v[0] + y * (v[1] + y * (v[2] + y * (v[3] + y * (v[4] + y * v[5]))));
                }

                // even powers of x: x^2, x^4, ..., x^12
                double xs[6];
                xs[0] = x * x;
                xs[1] = xs[0] * xs[0];
                xs[2] = xs[1] * xs[0];
                xs[3] = xs[1] * xs[1];
                xs[4] = xs[2] * xs[1];
                xs[5] = xs[2] * xs[2];

                // sum_k facs[k] * T_{2k}(x)
                double val =
                       facs[0] +
                       facs[1] * (2 * xs[0] - 1) +
                       facs[2] * (8 * xs[1] - 8 * xs[0] + 1) +
                       facs[3] * (32 * xs[2] - 48 * xs[1] + 18 * xs[0] - 1) +
                       facs[4] * (128 * xs[3] - 256 * xs[2] + 160 * xs[1] - 32 * xs[0] + 1) +
                       facs[5] * (512 * xs[4] - 1280 * xs[3] + 1120 * xs[2] - 400 * xs[1] + 50 * xs[0] - 1) +
                       facs[6] * (2048 * xs[5] - 6144 * xs[4] + 6912 * xs[3] - 3584 * xs[2] + 840 * xs[1] - 72 * xs[0] + 1);

                return isRev ? 1.0 / val : val;
        }

    private:
        // coefficient table dimensions: eta (Chebyshev) terms x pt terms
        enum { nEtaTerms = 7, nPtTerms = 6 };

        bool    isFixed;
        double  fixedValue;
        double  params[nEtaTerms][nPtTerms];
};
00106 
00107 class Var {
00108     public:
00109         Var(char type, TTree *tree, const char *name) :
00110                 type(type), var(getTaggingVariableName(name))
00111         {
00112                 switch(type) {
00113                     case 'D':
00114                         tree->SetBranchAddress(name, &D);
00115                         break;
00116                     case 'I':
00117                         tree->SetBranchAddress(name, &I);
00118                         break;
00119                     case 'd':
00120                         indirect = &d;
00121                         tree->SetBranchAddress(name, &indirect);
00122                         break;
00123                     case 'i':
00124                         indirect = &i;
00125                         tree->SetBranchAddress(name, &indirect);
00126                         break;
00127                 }
00128         }
00129 
00130         void fill(TaggingVariableList &list)
00131         {
00132                 switch(type) {
00133                     case 'D':
00134                         list.insert(var, D, true);
00135                         break;
00136                     case 'I':
00137                         list.insert(var, I, true);
00138                         break;
00139                     case 'd':
00140                         for(std::vector<double>::const_iterator p = d.begin();
00141                             p != d.end(); p++)
00142                                 list.insert(var, *p, true);
00143                         break;
00144                     case 'i':
00145                         for(std::vector<int>::const_iterator p = i.begin();
00146                             p != i.end(); p++)
00147                                 list.insert(var, *p, true);
00148                         break;
00149                 }
00150         }       
00151 
00152         static bool order(const boost::shared_ptr<Var> &a,
00153                           const boost::shared_ptr<Var> &b)
00154         { return a->var < b->var; }
00155 
00156     private:
00157         char                    type;
00158         TaggingVariableName     var;
00159         double                  D;
00160         int                     I;
00161         std::vector<double>     d;
00162         std::vector<int>        i;
00163         void                    *indirect;
00164 };
00165 
00166 class JetTagMVATreeTrainer : public edm::EDAnalyzer {
00167     public:
00168         explicit JetTagMVATreeTrainer(const edm::ParameterSet &params);
00169         ~JetTagMVATreeTrainer();
00170 
00171         virtual void beginRun(const edm::Run &run,
00172                               const edm::EventSetup &es);
00173 
00174         virtual void analyze(const edm::Event &event,
00175                              const edm::EventSetup &es);
00176 
00177     protected:
00178         bool isSignalFlavour(int flavour) const;
00179         bool isIgnoreFlavour(int flavour) const;
00180 
00181         std::auto_ptr<TagInfoMVACategorySelector>       categorySelector;
00182         std::auto_ptr<GenericMVAComputerCache>          computerCache;
00183 
00184         double                                          minPt;
00185         double                                          minEta;
00186         double                                          maxEta;
00187         double                                          factor;
00188         double                                          bound;
00189         double                                          signalFactor;
00190 
00191     private:
00192         std::vector<int>                                signalFlavours;
00193         std::vector<int>                                ignoreFlavours;
00194         Fit                                             weights;
00195         std::vector<Fit>                                bias;
00196         double                                          limiter;
00197         int                                             maxEvents;
00198         TRandom                                         rand;
00199 
00200         std::vector<std::string>                        fileNames;
00201 };
00202 
00203 JetTagMVATreeTrainer::JetTagMVATreeTrainer(const edm::ParameterSet &params) :
00204         minPt(params.getParameter<double>("minimumTransverseMomentum")),
00205         minEta(params.getParameter<double>("minimumPseudoRapidity")),
00206         maxEta(params.getParameter<double>("maximumPseudoRapidity")),
00207         factor(params.getParameter<double>("factor")),
00208         bound(params.getParameter<double>("bound")),
00209         signalFactor(params.getUntrackedParameter<double>("signalFactor", 1.0)),
00210         signalFlavours(params.getParameter<std::vector<int> >("signalFlavours")),
00211         ignoreFlavours(params.getParameter<std::vector<int> >("ignoreFlavours")),
00212         limiter(params.getUntrackedParameter<double>("weightThreshold", 0.0)),
00213         maxEvents(params.getUntrackedParameter<int>("maxEvents", -1)),
00214         fileNames(params.getParameter<std::vector<std::string> >("fileNames"))
00215 {
00216         std::sort(signalFlavours.begin(), signalFlavours.end());
00217         std::sort(ignoreFlavours.begin(), ignoreFlavours.end());
00218 
00219         std::vector<std::string> calibrationLabels;
00220         if (params.getParameter<bool>("useCategories")) {
00221                 categorySelector = std::auto_ptr<TagInfoMVACategorySelector>(
00222                                 new TagInfoMVACategorySelector(params));
00223 
00224                 calibrationLabels = categorySelector->getCategoryLabels();
00225         } else {
00226                 std::string calibrationRecord =
00227                         params.getParameter<std::string>("calibrationRecord");
00228 
00229                 calibrationLabels.push_back(calibrationRecord);
00230         }
00231 
00232         computerCache = std::auto_ptr<GenericMVAComputerCache>(
00233                         new GenericMVAComputerCache(calibrationLabels));
00234 
00235         weights = Fit(params.getParameter<std::string>("weightFile"));
00236 
00237         std::vector<std::string> biasFiles = params.getParameter<std::vector<std::string> >("biasFiles");
00238         for(std::vector<std::string>::const_iterator iter = biasFiles.begin();
00239             iter != biasFiles.end(); iter++) {
00240                 if (*iter == "*")
00241                         bias.push_back(Fit(1.0));
00242                 else if (*iter == "-")
00243                         bias.push_back(Fit(0.0));
00244                 else
00245                         bias.push_back(Fit(*iter));
00246         }
00247 }
00248 
00249 JetTagMVATreeTrainer::~JetTagMVATreeTrainer()
00250 {
00251 }
00252 
00253 bool JetTagMVATreeTrainer::isSignalFlavour(int flavour) const
00254 {
00255         std::vector<int>::const_iterator pos =
00256                 std::lower_bound(signalFlavours.begin(), signalFlavours.end(),
00257                                  flavour);
00258 
00259         return pos != signalFlavours.end() && *pos == flavour;
00260 }
00261 
00262 bool JetTagMVATreeTrainer::isIgnoreFlavour(int flavour) const
00263 {
00264         std::vector<int>::const_iterator pos =
00265                 std::lower_bound(ignoreFlavours.begin(), ignoreFlavours.end(),
00266                                  flavour);
00267 
00268         return pos != ignoreFlavours.end() && *pos == flavour;
00269 }
00270 
00271 void JetTagMVATreeTrainer::beginRun(const edm::Run& run,
00272                                     const edm::EventSetup& es)
00273 {
00274         rand.SetSeed(65539);
00275 }
00276 
00277 void JetTagMVATreeTrainer::analyze(const edm::Event& event,
00278                                    const edm::EventSetup& es)
00279 {
00280         // retrieve MVAComputer calibration container
00281         edm::ESHandle<Calibration::MVAComputerContainer> calibHandle;
00282         es.get<BTauGenericMVAJetTagComputerRcd>().get("trainer", calibHandle);
00283         const Calibration::MVAComputerContainer *calib = calibHandle.product();
00284 
00285         // check container for changes
00286         computerCache->update(calib);
00287         if (computerCache->isEmpty())
00288                 return;
00289 
00290         // cached array containing MVAComputer value list
00291         std::vector<Variable::Value> values;
00292         values.push_back(Variable::Value(MVATrainer::kTargetId, 0));
00293         values.push_back(Variable::Value(MVATrainer::kWeightId, 0));
00294 
00295         int nEvents = 0;
00296         for(std::vector<std::string>::const_iterator fName = fileNames.begin();
00297             fName != fileNames.end(); fName++) {
00298                 if (maxEvents >= 0 && nEvents >= maxEvents)
00299                         break;
00300 
00301                 std::auto_ptr<TFile> file(TFile::Open(fName->c_str()));
00302                 if (!file.get())
00303                         continue;
00304                 std::cout << "Opened " << *fName << std::endl;
00305 
00306                 TIter next(file->GetListOfKeys());
00307                 TObject *obj;
00308                 while((obj = next())) {
00309                         if (maxEvents >= 0 && nEvents >= maxEvents)
00310                                 break;
00311 
00312                         TTree *tree = dynamic_cast<TTree*>(file->Get(((TKey*)obj)->GetName()));
00313                         if (!tree)
00314                                 continue;
00315                         std::cout << "Tree " << tree->GetName() << std::endl;
00316 
00317                         int flavour;
00318                         tree->SetBranchAddress("flavour", &flavour);
00319 
00320                         std::vector< boost::shared_ptr<Var> > vars;
00321 
00322                         TIter branchIter(tree->GetListOfBranches());
00323                         while((obj = branchIter())) {
00324                                 TBranch *branch = dynamic_cast<TBranch*>(obj);
00325                                 if (!branch)
00326                                         continue;
00327 
00328                                 TString name = branch->GetName();
00329                                 TLeaf *leaf = dynamic_cast<TLeaf*>(
00330                                                         branch->GetLeaf(name));
00331                                 if (!leaf)
00332                                         continue;
00333 
00334                                 TString typeName = leaf->GetTypeName();
00335                                 char typeId;
00336                                 if (typeName == "Double_t")
00337                                         typeId = 'D';
00338                                 else if (typeName == "Int_t")
00339                                         typeId = 'I';
00340                                 else if (typeName == "vector<double>")
00341                                         typeId = 'd';
00342                                 else if (typeName == "vector<int>")
00343                                         typeId = 'i';
00344                                 else
00345                                         continue;
00346 
00347                                 if (getTaggingVariableName((const char *)name) ==
00348                                                 btau::lastTaggingVariable)
00349                                         continue;
00350                                 vars.push_back(boost::shared_ptr<Var>(
00351                                                 new Var(typeId, tree, name)));
00352                         }
00353                         std::sort(vars.begin(), vars.end(), &Var::order);
00354 
00355                         Long64_t entries = tree->GetEntries();
00356                         std::cout << "Entries " << entries << std::endl;
00357                         EventProgress progress(entries);
00358                         for(Long64_t entry = 0; entry < entries; entry++) {
00359                                 if (maxEvents >= 0 && nEvents >= maxEvents)
00360                                         break;
00361 
00362                                 progress.update(entry);
00363                                 tree->GetEntry(entry);
00364 
00365                                 TaggingVariableList variables;
00366                                 for(std::vector< boost::shared_ptr<Var> >::const_iterator iter = vars.begin();
00367                                     iter != vars.end(); iter++)
00368                                         (*iter)->fill(variables);
00369                                 variables.finalize();
00370 
00371                                 double jetPt = variables[btau::jetPt];
00372                                 double jetEta = variables[btau::jetEta];
00373 
00374                                 // simple jet filter
00375                                 if (jetPt < minPt ||
00376                                     std::abs(jetEta) < minEta ||
00377                                     std::abs(jetEta) > maxEta)
00378                                         continue;
00379 
00380                                 // do not train with unknown jet flavours
00381                                 if (isIgnoreFlavour(flavour))
00382                                         continue;
00383 
00384                                 // is it a b-jet?
00385                                 bool target = isSignalFlavour(flavour);
00386 
00387                                 // retrieve index of computer in case categories are used
00388                                 int index = 0;
00389                                 if (categorySelector.get()) {
00390                                         index = categorySelector->findCategory(variables);
00391                                         if (index < 0)
00392                                                 continue;
00393                                 }
00394 
00395                                 GenericMVAComputer *mvaComputer =
00396                                         computerCache->getComputer(index);
00397                                 if (!mvaComputer)
00398                                         continue;
00399 
00400                                 int idx = 0;
00401                                 if (flavour == 4)
00402                                         idx = 1;
00403                                 else if (flavour == 5 || flavour == 7)
00404                                         idx = 2;
00405                                 double pBias[3];
00406                                 for(int i = 0; i < 3; i++)
00407                                         pBias[i] = bias[i](jetPt, jetEta, i < 2);
00408                                 double weight;
00409                                 if (bias[0] && bias[1])
00410                                         weight = (idx == 0) ? 0.75 :
00411                                                  (idx == 1) ? 0.25 : 1.0;
00412                                 else
00413                                         weight = 1.0;
00414                                 weight *= jetPt / 50.0 + 1.0;
00415                                 weight /= 1.0 + std::exp((jetPt - 600.0) / 150.0);
00416                                 weight *= (1.0 - 0.01 * std::exp(0.5 * jetEta * jetEta));
00417                                 weight /= weights(jetPt, jetEta);
00418                                 weight *= pBias[0] + pBias[1] + pBias[2];
00419                                 weight /= pBias[idx];
00420 
00421                                 weight *= factor;
00422                                 if (weight > bound)
00423                                         weight = bound;
00424 
00425                                 if (idx == 2)
00426                                         weight *= signalFactor;
00427 
00428                                 if (weight < limiter) {
00429                                         if (rand.Uniform(limiter) > weight)
00430                                                 continue;
00431                                         weight = limiter;
00432                                 }
00433 
00434                                 // composite full array of MVAComputer values
00435                                 values.resize(2 + variables.size());
00436 
00437                                 std::vector<Variable::Value>::iterator insert = values.begin();
00438                                 (insert++)->setValue(target);
00439                                 (insert++)->setValue(weight);
00440 
00441                                 std::copy(mvaComputer->iterator(variables.begin()),
00442                                           mvaComputer->iterator(variables.end()), insert);
00443 
00444                                 static_cast<MVAComputer*>(mvaComputer)->eval(values);
00445 
00446                                 nEvents++;
00447                         }
00448                 }
00449         }
00450 }
00451 
// the main module
// Framework module registration for JetTagMVATreeTrainer; the macro comes
// from FWCore/Framework/interface/MakerMacros.h.
DEFINE_FWK_MODULE(JetTagMVATreeTrainer);