CMS 3D CMS Logo

/data/doxygen/doxygen-1.7.3/gen/CMSSW_4_2_8/src/RecoBTau/JetTagMVALearning/src/JetTagMVATrainer.cc

Go to the documentation of this file.
00001 #include <functional>
00002 #include <algorithm>
00003 #include <iostream>
00004 #include <vector>
00005 #include <memory>
00006 #include <cmath>
00007 #include <map>
00008 
00009 #include "FWCore/Utilities/interface/Exception.h"
00010 #include "FWCore/Utilities/interface/EDMException.h"
00011 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00012 #include "FWCore/Utilities/interface/InputTag.h"
00013 #include "FWCore/Framework/interface/Event.h"
00014 #include "FWCore/Framework/interface/EventSetup.h"
00015 #include "FWCore/Framework/interface/ESHandle.h"
00016 #include "FWCore/Framework/interface/EDAnalyzer.h"
00017 
00018 #include "SimDataFormats/JetMatching/interface/JetFlavourMatching.h"
00019 
00020 #include "DataFormats/Common/interface/Ref.h"
00021 #include "DataFormats/Common/interface/View.h"
00022 #include "DataFormats/BTauReco/interface/JetTagInfo.h"
00023 #include "DataFormats/BTauReco/interface/TaggingVariable.h"
00024 
00025 #include "CondFormats/PhysicsToolsObjects/interface/MVAComputer.h"
00026 #include "CondFormats/DataRecord/interface/BTauGenericMVAJetTagComputerRcd.h"
00027 
00028 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00029 
00030 #include "RecoBTau/JetTagComputer/interface/JetTagComputerRecord.h"
00031 #include "RecoBTau/JetTagComputer/interface/GenericMVAComputer.h"
00032 #include "RecoBTau/JetTagComputer/interface/GenericMVAComputerCache.h"
00033 #include "RecoBTau/JetTagComputer/interface/TagInfoMVACategorySelector.h"
00034 #include "RecoBTau/JetTagMVALearning/interface/JetTagMVATrainer.h"
00035 
00036 using namespace reco;
00037 using namespace PhysicsTools;
00038 
00039 static const AtomicId kJetPt(TaggingVariableTokens[btau::jetPt]);
00040 static const AtomicId kJetEta(TaggingVariableTokens[btau::jetEta]);
00041 
00042 JetTagMVATrainer::JetTagMVATrainer(const edm::ParameterSet &params) :
00043         jetFlavour(params.getParameter<edm::InputTag>("jetFlavourMatching")),
00044         minPt(params.getParameter<double>("minimumTransverseMomentum")),
00045         minEta(params.getParameter<double>("minimumPseudoRapidity")),
00046         maxEta(params.getParameter<double>("maximumPseudoRapidity")),
00047         setupDone(false),
00048         jetTagComputer(params.getParameter<std::string>("jetTagComputer")),
00049         tagInfos(params.getParameter< std::vector<edm::InputTag> >("tagInfos")),
00050         signalFlavours(params.getParameter<std::vector<int> >("signalFlavours")),
00051         ignoreFlavours(params.getParameter<std::vector<int> >("ignoreFlavours"))
00052 {
00053         std::sort(signalFlavours.begin(), signalFlavours.end());
00054         std::sort(ignoreFlavours.begin(), ignoreFlavours.end());
00055 
00056         std::vector<std::string> calibrationLabels;
00057         if (params.getParameter<bool>("useCategories")) {
00058                 categorySelector.reset(new TagInfoMVACategorySelector(params));
00059 
00060                 calibrationLabels = categorySelector->getCategoryLabels();
00061         } else {
00062                 std::string calibrationRecord =
00063                         params.getParameter<std::string>("calibrationRecord");
00064 
00065                 calibrationLabels.push_back(calibrationRecord);
00066         }
00067 
00068         computerCache.reset(new GenericMVAComputerCache(calibrationLabels));
00069 }
00070 
00071 JetTagMVATrainer::~JetTagMVATrainer()
00072 {
00073 }
00074 
00075 void JetTagMVATrainer::setup(const JetTagComputer &computer)
00076 {
00077         std::vector<std::string> inputLabels(computer.getInputLabels());
00078 
00079         if (inputLabels.empty())
00080                 inputLabels.push_back("tagInfo");
00081 
00082         if (tagInfos.size() != inputLabels.size()) {
00083                 std::string message("VInputTag size mismatch - the following "
00084                                     "taginfo labels are needed:\n");
00085                 for(std::vector<std::string>::const_iterator iter =
00086                         inputLabels.begin(); iter != inputLabels.end(); ++iter)
00087                         message += "\"" + *iter + "\"\n";
00088                 throw edm::Exception(edm::errors::Configuration) << message;
00089         }
00090 
00091         setupDone = true;
00092 }
00093 
00094 // map helper
00095 namespace {
00096         struct JetCompare :
00097                 public std::binary_function<edm::RefToBase<Jet>,
00098                                             edm::RefToBase<Jet>, bool> {
00099                 inline bool operator () (const edm::RefToBase<Jet> &j1,
00100                                          const edm::RefToBase<Jet> &j2) const
00101                 { return j1.key() < j2.key(); }
00102         };
00103 }
00104 
00105 struct JetTagMVATrainer::JetInfo {
00106         JetInfo() : flavour(0)
00107         { leptons[0] = leptons[1] = leptons[2] = 0; }
00108 
00109         unsigned int            flavour;
00110         bool                    leptons[3];
00111         std::vector<int>        tagInfos;
00112 };
00113 
// Tell whether `flavour` occurs in `list`.
//
// `list` must be sorted in ascending order, since the lookup is a binary
// search. Returns false for an empty list.
static bool isFlavour(int flavour, const std::vector<int> &list)
{
        return std::binary_search(list.begin(), list.end(), flavour);
}
00121 
00122 bool JetTagMVATrainer::isFlavour(const JetInfo &info,
00123                                  const std::vector<int> &list)
00124 {
00125         if (::isFlavour(info.flavour, list))
00126                 return true;
00127         else if (info.flavour < 4)
00128                 return false;
00129 
00130         for(unsigned int i = 1; i <= 3; i++)
00131                 if (info.leptons[i - 1] &&
00132                     ::isFlavour(info.flavour * 10 + i, list))
00133                         return true;
00134 
00135         return false;
00136 }
00137 
00138 bool JetTagMVATrainer::isSignalFlavour(const JetInfo &info) const
00139 {
00140         return isFlavour(info, signalFlavours);
00141 }
00142 
00143 bool JetTagMVATrainer::isIgnoreFlavour(const JetInfo &info) const
00144 {
00145         return isFlavour(info, ignoreFlavours);
00146 }
00147 
00148 void JetTagMVATrainer::analyze(const edm::Event& event,
00149                                const edm::EventSetup& es)
00150 {
00151         // retrieve MVAComputer calibration container
00152         edm::ESHandle<Calibration::MVAComputerContainer> calibHandle;
00153         es.get<BTauGenericMVAJetTagComputerRcd>().get("trainer", calibHandle);
00154         const Calibration::MVAComputerContainer *calib = calibHandle.product();
00155 
00156         // check container for changes
00157         computerCache->update(calib);
00158         if (computerCache->isEmpty())
00159                 return;
00160 
00161         // retrieve JetTagComputer
00162         edm::ESHandle<JetTagComputer> computerHandle;
00163         es.get<JetTagComputerRecord>().get(jetTagComputer, computerHandle);
00164         const GenericMVAJetTagComputer *computer =
00165                         dynamic_cast<const GenericMVAJetTagComputer*>(
00166                                                 computerHandle.product());
00167         if (!computer)
00168                 throw cms::Exception("InvalidCast")
00169                         << "JetTagComputer is not a MVAJetTagComputer "
00170                            "in JetTagMVATrainer" << std::endl;
00171 
00172         computer->passEventSetup(es);
00173 
00174         // finalize the JetTagMVALearning <-> JetTagComputer glue setup
00175         if (!setupDone)
00176                 setup(*computer);
00177 
00178         // retrieve TagInfos
00179         typedef std::map<edm::RefToBase<Jet>, JetInfo, JetCompare> JetInfoMap;
00180         JetInfoMap jetInfos;
00181 
00182         std::vector< edm::Handle< edm::View<BaseTagInfo> > >
00183                                         tagInfoHandles(tagInfos.size());
00184         unsigned int nTagInfos = tagInfos.size();
00185         for(unsigned int i = 0; i < nTagInfos; i++) {
00186                 edm::Handle< edm::View<BaseTagInfo> > &tagInfoHandle =
00187                                                         tagInfoHandles[i];
00188                 event.getByLabel(tagInfos[i], tagInfoHandle);
00189 
00190                 int j = 0;
00191                 for(edm::View<BaseTagInfo>::const_iterator iter =
00192                         tagInfoHandle->begin();
00193                                 iter != tagInfoHandle->end(); iter++, j++) {
00194 
00195                         JetInfo &jetInfo = jetInfos[iter->jet()];
00196                         if (jetInfo.tagInfos.empty()) {
00197                                 jetInfo.tagInfos.resize(nTagInfos, -1);
00198                         }
00199 
00200                         jetInfo.tagInfos[i] = j;
00201                 }
00202         }
00203 
00204         // retrieve jet flavours;
00205         edm::Handle<JetFlavourMatchingCollection> jetFlavourHandle;
00206         event.getByLabel(jetFlavour, jetFlavourHandle);
00207 
00208         for(JetFlavourMatchingCollection::const_iterator iter =
00209                 jetFlavourHandle->begin();
00210                                 iter != jetFlavourHandle->end(); iter++) {
00211 
00212                 JetInfoMap::iterator pos =
00213                         jetInfos.find(edm::RefToBase<Jet>(iter->first));
00214                 if (pos != jetInfos.end()) {
00215                         int flavour = iter->second.getFlavour();
00216                         flavour = std::abs(flavour);
00217                         if (flavour < 100) {
00218                                 JetFlavour::Leptons leptons =
00219                                                 iter->second.getLeptons();
00220 
00221                                 pos->second.flavour = flavour;
00222                                 pos->second.leptons[0] = leptons.electron > 0;
00223                                 pos->second.leptons[1] = leptons.muon > 0;
00224                                 pos->second.leptons[2] = leptons.tau > 0;
00225                         }
00226                 }
00227         }
00228 
00229         // cached array containing MVAComputer value list
00230         std::vector<Variable::Value> values;
00231         values.push_back(Variable::Value(MVATrainer::kTargetId, 0));
00232         values.push_back(Variable::Value(kJetPt, 0));
00233         values.push_back(Variable::Value(kJetEta, 0));
00234 
00235         // now loop over the map and compute all JetTags
00236         for(JetInfoMap::const_iterator iter = jetInfos.begin();
00237             iter != jetInfos.end(); iter++) {
00238                 edm::RefToBase<Jet> jet = iter->first;
00239                 const JetInfo &info = iter->second;
00240 
00241                 // simple jet filter
00242                 if (jet->pt() < minPt ||
00243                     std::abs(jet->eta()) < minEta ||
00244                     std::abs(jet->eta()) > maxEta)
00245                         continue;
00246 
00247                 // do not train with unknown jet flavours
00248                 if (isIgnoreFlavour(info))
00249                         continue;
00250 
00251                 // is it a b-jet?
00252                 bool target = isSignalFlavour(info);
00253 
00254                 // build TagInfos pointer for helper
00255                 std::vector<const BaseTagInfo*> tagInfoPtrs(nTagInfos);
00256                 for(unsigned int i = 0; i < nTagInfos; i++)  {
00257                         if (info.tagInfos[i] < 0)
00258                                 continue;
00259 
00260                         tagInfoPtrs[i] =
00261                                 &tagInfoHandles[i]->at(info.tagInfos[i]);
00262                 }
00263                 JetTagComputer::TagInfoHelper helper(tagInfoPtrs);
00264 
00265                 TaggingVariableList variables =
00266                                         computer->taggingVariables(helper);
00267 
00268                 // retrieve index of computer in case categories are used
00269                 int index = 0;
00270                 if (categorySelector.get()) {
00271                         index = categorySelector->findCategory(variables);
00272                         if (index < 0)
00273                                 continue;
00274                 }
00275 
00276                 GenericMVAComputer *mvaComputer =
00277                                         computerCache->getComputer(index);
00278                 if (!mvaComputer)
00279                         continue;
00280 
00281                 // composite full array of MVAComputer values
00282                 values.resize(3 + variables.size());
00283                 std::vector<Variable::Value>::iterator insert = values.begin();
00284 
00285                 (insert++)->setValue(target);
00286                 (insert++)->setValue(jet->pt());
00287                 (insert++)->setValue(jet->eta());
00288                 std::copy(mvaComputer->iterator(variables.begin()),
00289                           mvaComputer->iterator(variables.end()), insert);
00290 
00291                 static_cast<MVAComputer*>(mvaComputer)->eval(values);
00292         }
00293 }