CMS 3D CMS Logo

/data/doxygen/doxygen-1.7.3/gen/CMSSW_4_2_8/src/RecoBTau/JetTagMVALearning/plugins/JetTagMVAExtractor.cc

Go to the documentation of this file.
00001 #include <functional>
00002 #include <algorithm>
00003 #include <iterator>
00004 #include <iostream>
00005 #include <vector>
00006 #include <memory>
00007 #include <cmath>
00008 #include <map>
00009 
00010 #include <boost/shared_ptr.hpp>
00011 
00012 #include <TDirectory.h>
00013 #include <TTree.h>
00014 #include <TFile.h>
00015 
00016 #include "FWCore/Utilities/interface/Exception.h"
00017 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00018 #include "FWCore/Utilities/interface/InputTag.h"
00019 #include "FWCore/Framework/interface/Event.h"
00020 #include "FWCore/Framework/interface/EventSetup.h"
00021 #include "FWCore/Framework/interface/ESHandle.h"
00022 #include "FWCore/Framework/interface/EDAnalyzer.h"
00023 
00024 #include "SimDataFormats/JetMatching/interface/JetFlavourMatching.h"
00025 
00026 #include "DataFormats/Common/interface/View.h"
00027 #include "DataFormats/BTauReco/interface/JetTagInfo.h"
00028 #include "DataFormats/BTauReco/interface/TaggingVariable.h"
00029 
00030 #include "CondFormats/PhysicsToolsObjects/interface/MVAComputer.h"
00031 
00032 #include "RecoBTau/JetTagComputer/interface/JetTagComputer.h"
00033 #include "RecoBTau/JetTagComputer/interface/JetTagComputerRecord.h"
00034 #include "RecoBTau/JetTagComputer/interface/GenericMVAComputer.h"
00035 #include "RecoBTau/JetTagComputer/interface/GenericMVAJetTagComputer.h"
00036 #include "RecoBTau/JetTagComputer/interface/TagInfoMVACategorySelector.h"
00037 
00038 using namespace reco;
00039 using namespace PhysicsTools;
00040 
00041 namespace { // anonymous
00042 
00043 class ROOTContextSentinel {
00044     public:
00045         ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
00046         ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
00047 
00048     private:
00049         TDirectory      *dir;
00050         TFile           *file;
00051 };
00052 
00053 } // anonymous namespace
00054 
00055 static const AtomicId kJetPt(TaggingVariableTokens[btau::jetPt]);
00056 static const AtomicId kJetEta(TaggingVariableTokens[btau::jetEta]);
00057 
00058 class JetTagMVAExtractor : public edm::EDAnalyzer {
00059     public:
00060         explicit JetTagMVAExtractor(const edm::ParameterSet &params);
00061         ~JetTagMVAExtractor();
00062 
00063         virtual void analyze(const edm::Event &event,
00064                              const edm::EventSetup &es);
00065 
00066     private:
00067         typedef std::vector<Variable::Value> Values;
00068 
00069         struct Index {
00070                 inline Index(int flavour, int index) :
00071                         index(index), flavour(flavour) {}
00072 
00073                 inline bool operator == (const Index &rhs) const
00074                 { return index == rhs.index && flavour == rhs.flavour; }
00075 
00076                 inline bool operator < (const Index &rhs) const
00077                 { return index == rhs.index ? (flavour < rhs.flavour) : (index < rhs.index); }
00078 
00079                 int     index;
00080                 int     flavour;
00081         };
00082 
00083         struct Tree {
00084                 Tree(const JetTagMVAExtractor &main, Index index);
00085                 ~Tree();
00086 
00087                 struct Value {
00088                         Value() : type(0), multiple(false) {}
00089                         Value(char type, bool multiple) : type(type), multiple(multiple) {}
00090 
00091                         void clear() { sInt = -999; sDouble = -999.0; vInt.clear(); vDouble.clear(); }
00092                         void set(double value)
00093                         {
00094                                 if (type == 'I' && multiple)
00095                                         vInt.push_back((int)std::floor(value + 0.5));
00096                                 else if (type == 'D' && multiple)
00097                                         vDouble.push_back(value);
00098                                 else if (type == 'I' && !multiple)
00099                                         sInt = (int)std::floor(value + 0.5);
00100                                 else if (type == 'D' && !multiple)
00101                                         sDouble = value;
00102                         }
00103 
00104                         char                    type;
00105                         bool                    multiple;
00106 
00107                         void                    *indirect;
00108                         Int_t                   sInt;
00109                         Double_t                sDouble;
00110                         std::vector<int>        vInt;
00111                         std::vector<double>     vDouble;
00112                 };
00113 
00114                 int                             flavour;
00115                 TTree                           *tree;
00116                 std::auto_ptr<TFile>            file;
00117                 std::map<AtomicId, Value>       values;
00118         };
00119 
00120         struct Label {
00121                 Label() {}
00122                 Label(const edm::ParameterSet &pset);
00123                 Label(const Label &label) : variables(label.variables), label(label.label) {}
00124 
00125                 struct Var {
00126                         Var(const std::string &name);
00127 
00128                         AtomicId        id;
00129                         char            type;
00130                         bool            multiple;
00131                 };
00132 
00133                 std::vector<Var>        variables;
00134                 std::string             label;
00135         };
00136 
00137         friend class Tree;
00138 
00139         void setup(const JetTagComputer &computer);
00140         void process(Index index, const Values &values);
00141 
00142         edm::InputTag                                   jetFlavour;
00143         std::auto_ptr<TagInfoMVACategorySelector>       categorySelector;
00144 
00145         double                                          minPt;
00146         double                                          minEta;
00147         double                                          maxEta;
00148 
00149         bool                                            setupDone;
00150         std::string                                     jetTagComputer;
00151         const GenericMVAComputer                        mvaComputer;
00152 
00153         std::map<std::string, edm::InputTag>            tagInfoLabels;
00154         std::vector<edm::InputTag>                      tagInfos;
00155 
00156         std::vector<Label>                              calibrationLabels;
00157 
00158         std::map<Index, boost::shared_ptr<Tree> >       treeMap;
00159 };
00160 
00161 JetTagMVAExtractor::Tree::Tree(const JetTagMVAExtractor &main, Index index) :
00162         flavour(index.flavour)
00163 {
00164         static const char flavourMap[] = " DUSCBTB             G";
00165 
00166         if (index.index < 0 || index.index >= (int)main.calibrationLabels.size())
00167                 return;
00168 
00169         ROOTContextSentinel ctx;
00170 
00171         Label label = main.calibrationLabels[index.index];
00172 if (index.flavour > 21) std::cout << index.flavour << std::endl;
00173         std::string flavour = std::string("") + flavourMap[index.flavour];
00174         file.reset(new TFile((label.label + "_" + flavour + ".root").c_str(), "RECREATE"));
00175         file->cd();
00176 
00177         tree = new TTree(label.label.c_str(), (label.label + "_" + flavour).c_str());
00178 
00179         tree->Branch("flavour", &this->flavour, "flavour/I");
00180 
00181         for(std::vector<Label::Var>::const_iterator iter = label.variables.begin();
00182             iter != label.variables.end(); iter++) {
00183                 values[iter->id] = Value(iter->type, iter->multiple);
00184                 Value &value = values[iter->id];
00185                 const char *name = iter->id;
00186 
00187                 if (iter->type == 'I' && !iter->multiple) {
00188                         tree->Branch(name, &value.sInt,
00189                                      (std::string(name) + "/I").c_str());
00190                 } else if (iter->type == 'D' && !iter->multiple) {
00191                         tree->Branch(name, &value.sDouble,
00192                                      (std::string(name) + "/D").c_str());
00193                 } else if (iter->type == 'I' && iter->multiple) {
00194                         value.indirect = &value.vInt;
00195                         tree->Branch(name, "std::vector<int>",
00196                                      &value.indirect);
00197                 } else if (iter->type == 'D' && iter->multiple) {
00198                         value.indirect = &value.vDouble;
00199                         tree->Branch(name, "std::vector<double>",
00200                                      &value.indirect);
00201                 }
00202         }
00203 }
00204 
00205 JetTagMVAExtractor::Tree::~Tree()
00206 {
00207         if (!tree)
00208                 return;
00209 
00210         ROOTContextSentinel ctx;
00211 
00212         file->cd();
00213         tree->Write();
00214         file->Close();
00215 }
00216 
00217 JetTagMVAExtractor::Label::Label(const edm::ParameterSet &pset) :
00218         label(pset.getUntrackedParameter<std::string>("label"))
00219 {
00220         std::vector<std::string> vars =
00221                 pset.getUntrackedParameter< std::vector<std::string> >("variables");
00222         std::copy(vars.begin(), vars.end(), std::back_inserter(variables));
00223 }
00224 
00225 JetTagMVAExtractor::Label::Var::Var(const std::string &name) :
00226         id(name)
00227 {
00228         TaggingVariableName tag = getTaggingVariableName(name);
00229         if (tag == btau::lastTaggingVariable)
00230                 throw cms::Exception("UnknownTaggingVariable")
00231                         << "Unknown tagging variable " << name << std::endl;
00232 
00233         multiple = ((int)tag >= (int)btau::trackMomentum &&
00234                     (int)tag <= (int)btau::trackGhostTrackWeight) ||
00235                    ((int)tag >= (int)btau::trackP0Par &&
00236                     (int)tag <= (int)btau::algoDiscriminator);
00237 
00238         type = (tag == btau::jetNTracks ||
00239                 tag == btau::vertexCategory ||
00240                 tag == btau::jetNSecondaryVertices ||
00241                 tag == btau::vertexNTracks ||
00242                 tag == btau::trackNTotalHits ||
00243                 tag == btau::trackNPixelHits) ? 'I' : 'D';
00244 }
00245 
00246 static const Calibration::MVAComputer *dummyCalib()
00247 {
00248         static Calibration::MVAComputer dummy;
00249         static bool init = false;
00250 
00251         if (!init)
00252                 dummy.inputSet.push_back(Calibration::Variable());
00253 
00254         return &dummy;
00255 }
00256 
00257 JetTagMVAExtractor::JetTagMVAExtractor(const edm::ParameterSet &params) :
00258         jetFlavour(params.getParameter<edm::InputTag>("jetFlavourMatching")),
00259         minPt(params.getParameter<double>("minimumTransverseMomentum")),
00260         minEta(params.getParameter<double>("minimumPseudoRapidity")),
00261         maxEta(params.getParameter<double>("maximumPseudoRapidity")),
00262         setupDone(false),
00263         jetTagComputer(params.getParameter<std::string>("jetTagComputer")),
00264         mvaComputer(dummyCalib())
00265 {
00266         std::vector<std::string> labels;
00267 
00268         if (params.getParameter<bool>("useCategories")) {
00269                 categorySelector = std::auto_ptr<TagInfoMVACategorySelector>(
00270                                 new TagInfoMVACategorySelector(params));
00271 
00272                 labels = categorySelector->getCategoryLabels();
00273         } else {
00274                 std::string calibrationRecord =
00275                         params.getParameter<std::string>("calibrationRecord");
00276 
00277                 labels.push_back(calibrationRecord);
00278         }
00279 
00280         std::vector<edm::ParameterSet> variables =
00281                 params.getUntrackedParameter< std::vector<edm::ParameterSet> >("variables");
00282 
00283         std::map<std::string, Label> labelMap;
00284 
00285         for(std::vector<edm::ParameterSet>::const_iterator iter = variables.begin();
00286             iter != variables.end(); iter++) {
00287                 Label label(*iter);
00288                 if (labelMap.count(label.label))
00289                         throw cms::Exception("DuplVariables")
00290                                 << "Duplicated label for variables "
00291                                 << label.label << std::endl;
00292                 labelMap[label.label] = label;
00293         }
00294 
00295         if (labelMap.size() != labels.size())
00296                 throw cms::Exception("MismatchVariables")
00297                         << "Label variables mismatch." << std::endl;
00298 
00299         for(std::vector<std::string>::const_iterator iter = labels.begin();
00300             iter != labels.end(); iter++) {
00301                 std::map<std::string, Label>::const_iterator pos =
00302                                                         labelMap.find(*iter);
00303                 if (pos == labelMap.end())
00304                         throw cms::Exception("MismatchVariables")
00305                                 << "Variables definition for " << *iter
00306                                 << " not found." << std::endl;
00307 
00308                 calibrationLabels.push_back(pos->second);
00309         }
00310 
00311         std::vector<std::string> inputTags =
00312                         params.getParameterNamesForType<edm::InputTag>();
00313 
00314         for(std::vector<std::string>::const_iterator iter = inputTags.begin();
00315             iter != inputTags.end(); iter++)
00316                 tagInfoLabels[*iter] =
00317                                 params.getParameter<edm::InputTag>(*iter);
00318 }
00319 
00320 JetTagMVAExtractor::~JetTagMVAExtractor()
00321 {
00322 }
00323 
00324 void JetTagMVAExtractor::setup(const JetTagComputer &computer)
00325 {
00326         std::vector<std::string> inputLabels = computer.getInputLabels();
00327         if (inputLabels.empty())
00328                 inputLabels.push_back("tagInfo");
00329 
00330         for(std::vector<std::string>::const_iterator iter = inputLabels.begin();
00331             iter != inputLabels.end(); iter++) {
00332                 std::map<std::string, edm::InputTag>::const_iterator pos =
00333                                                 tagInfoLabels.find(*iter);
00334                 if (pos == tagInfoLabels.end())
00335                         throw cms::Exception("InputTagMissing")
00336                                 << "JetTagMVAExtractor is missing a TagInfo "
00337                                    "InputTag \"" << *iter << "\"" << std::endl;
00338 
00339                 tagInfos.push_back(pos->second);
00340         }
00341 
00342         setupDone = true;
00343 }
00344 
00345 // map helper
00346 namespace {
00347         struct JetCompare :
00348                 public std::binary_function<edm::RefToBase<Jet>,
00349                                             edm::RefToBase<Jet>, bool> {
00350                 inline bool operator () (const edm::RefToBase<Jet> &j1,
00351                                          const edm::RefToBase<Jet> &j2) const
00352                 { return j1.key() < j2.key(); }
00353         };
00354 
00355         struct JetInfo {
00356                 unsigned int            flavour;
00357                 std::vector<int>        tagInfos;
00358         };
00359 }
00360 
00361 void JetTagMVAExtractor::analyze(const edm::Event& event,
00362                                  const edm::EventSetup& es)
00363 {
00364         // retrieve JetTagComputer
00365         edm::ESHandle<JetTagComputer> computerHandle;
00366         es.get<JetTagComputerRecord>().get(jetTagComputer, computerHandle);
00367         const GenericMVAJetTagComputer *computer =
00368                         dynamic_cast<const GenericMVAJetTagComputer*>(
00369                                                 computerHandle.product());
00370         if (!computer)
00371                 throw cms::Exception("InvalidCast")
00372                         << "JetTagComputer is not a MVAJetTagComputer "
00373                            "in JetTagMVAExtractor" << std::endl;
00374 
00375         computer->passEventSetup(es);
00376 
00377         // finalize the JetTagMVALearning <-> JetTagComputer glue setup
00378         if (!setupDone)
00379                 setup(*computer);
00380 
00381         // retrieve TagInfos
00382         typedef edm::RefToBase<Jet> JetRef;
00383         typedef std::map<JetRef, JetInfo, JetCompare> JetInfoMap;
00384         JetInfoMap jetInfos;
00385 
00386         std::vector< edm::Handle< edm::View<BaseTagInfo> > >
00387                                         tagInfoHandles(tagInfos.size());
00388         unsigned int nTagInfos = tagInfos.size();
00389         for(unsigned int i = 0; i < nTagInfos; i++) {
00390                 edm::Handle< edm::View<BaseTagInfo> > &tagInfoHandle =
00391                                                         tagInfoHandles[i];
00392                 event.getByLabel(tagInfos[i], tagInfoHandle);
00393 
00394                 int j = 0;
00395                 for(edm::View<BaseTagInfo>::const_iterator iter =
00396                         tagInfoHandle->begin();
00397                                 iter != tagInfoHandle->end(); iter++, j++) {
00398 
00399                         JetInfo &jetInfo = jetInfos[iter->jet()];
00400                         if (jetInfo.tagInfos.empty()) {
00401                                 jetInfo.flavour = 0;
00402                                 jetInfo.tagInfos.resize(nTagInfos, -1);
00403                         }
00404 
00405                         jetInfo.tagInfos[i] = j;
00406                 }
00407         }
00408 
00409         // retrieve jet flavours;
00410         edm::Handle<JetFlavourMatchingCollection> jetFlavourHandle;
00411         event.getByLabel(jetFlavour, jetFlavourHandle);
00412 
00413         for(JetFlavourMatchingCollection::const_iterator iter =
00414                 jetFlavourHandle->begin();
00415                                 iter != jetFlavourHandle->end(); iter++) {
00416 
00417                 JetInfoMap::iterator pos = jetInfos.find(iter->first);
00418                 if (pos != jetInfos.end())
00419                         pos->second.flavour =
00420                                         std::abs(iter->second.getFlavour());
00421         }
00422 
00423         // cached array containing MVAComputer value list
00424         std::vector<Variable::Value> values;
00425         values.push_back(Variable::Value(kJetPt, 0));
00426         values.push_back(Variable::Value(kJetEta, 0));
00427 
00428         // now loop over the map and compute all JetTags
00429         for(JetInfoMap::const_iterator iter = jetInfos.begin();
00430             iter != jetInfos.end(); iter++) {
00431                 edm::RefToBase<Jet> jet = iter->first;
00432                 const JetInfo &info = iter->second;
00433 
00434                 // simple jet filter
00435                 if (jet->pt() < minPt ||
00436                     std::abs(jet->eta()) < minEta ||
00437                     std::abs(jet->eta()) > maxEta)
00438                         continue;
00439 
00440                 // do not train with unknown jet flavours
00441                 if (!info.flavour)
00442                         continue;
00443 
00444                 // build TagInfos pointer for helper
00445                 std::vector<const BaseTagInfo*> tagInfoPtrs(nTagInfos);
00446                 for(unsigned int i = 0; i < nTagInfos; i++)  {
00447                         if (info.tagInfos[i] < 0)
00448                                 continue;
00449 
00450                         tagInfoPtrs[i] =
00451                                 &tagInfoHandles[i]->at(info.tagInfos[i]);
00452                 }
00453                 JetTagComputer::TagInfoHelper helper(tagInfoPtrs);
00454 
00455                 TaggingVariableList variables =
00456                                         computer->taggingVariables(helper);
00457 
00458                 // retrieve index of computer in case categories are used
00459                 int index = 0;
00460                 if (categorySelector.get()) {
00461                         index = categorySelector->findCategory(variables);
00462                         if (index < 0)
00463                                 continue;
00464                 }
00465 
00466                 // composite full array of MVAComputer values
00467                 values.resize(2 + variables.size());
00468                 std::vector<Variable::Value>::iterator insert = values.begin();
00469 
00470                 (insert++)->setValue(jet->pt());
00471                 (insert++)->setValue(jet->eta());
00472                 std::copy(mvaComputer.iterator(variables.begin()),
00473                           mvaComputer.iterator(variables.end()), insert);
00474 
00475                 process(Index(info.flavour, index), values);
00476         }
00477 }
00478 
00479 void JetTagMVAExtractor::process(Index index, const Values &values)
00480 {
00481         if (index.flavour == 7)
00482                 index.flavour = 5;
00483 
00484         std::map<Index, boost::shared_ptr<Tree> >::iterator pos = treeMap.find(index);
00485         Tree *tree;
00486 
00487         if (pos == treeMap.end())
00488                 tree = treeMap.insert(std::make_pair(index, boost::shared_ptr<Tree>(new Tree(*this, index)))).first->second.get();
00489         else
00490                 tree = pos->second.get();
00491 
00492         if (!tree->tree)
00493                 return;
00494 
00495         for(std::map<AtomicId, Tree::Value>::iterator iter = tree->values.begin();
00496             iter != tree->values.end(); iter++)
00497                 iter->second.clear();
00498 
00499         for(Values::const_iterator iter = values.begin();
00500             iter != values.end(); iter++) {
00501                 std::map<AtomicId, Tree::Value>::iterator pos = tree->values.find(iter->getName());
00502                 if (pos == tree->values.end())
00503                         throw cms::Exception("VarNotFound")
00504                                 << "Variable " << (const char*)iter->getName()
00505                                 << " not found." << std::endl;
00506 
00507                 pos->second.set(iter->getValue());
00508         }
00509 
00510         tree->tree->Fill();
00511 }
00512 
00513 #include "FWCore/Framework/interface/MakerMacros.h"
00514 
00515 // the main module
00516 DEFINE_FWK_MODULE(JetTagMVAExtractor);