// Go to the documentation of this file.
#include <functional>
00002 #include <algorithm>
00003 #include <iostream>
00004 #include <vector>
00005 #include <memory>
00006 #include <cmath>
00007 #include <map>
00008
00009 #include "FWCore/Utilities/interface/Exception.h"
00010 #include "FWCore/Utilities/interface/EDMException.h"
00011 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00012 #include "FWCore/Utilities/interface/InputTag.h"
00013 #include "FWCore/Framework/interface/Event.h"
00014 #include "FWCore/Framework/interface/EventSetup.h"
00015 #include "FWCore/Framework/interface/ESHandle.h"
00016 #include "FWCore/Framework/interface/EDAnalyzer.h"
00017
00018 #include "SimDataFormats/JetMatching/interface/JetFlavourMatching.h"
00019
00020 #include "DataFormats/Common/interface/Ref.h"
00021 #include "DataFormats/Common/interface/View.h"
00022 #include "DataFormats/BTauReco/interface/JetTagInfo.h"
00023 #include "DataFormats/BTauReco/interface/TaggingVariable.h"
00024
00025 #include "CondFormats/PhysicsToolsObjects/interface/MVAComputer.h"
00026 #include "CondFormats/DataRecord/interface/BTauGenericMVAJetTagComputerRcd.h"
00027
00028 #include "PhysicsTools/MVATrainer/interface/MVATrainer.h"
00029
00030 #include "RecoBTau/JetTagComputer/interface/JetTagComputerRecord.h"
00031 #include "RecoBTau/JetTagComputer/interface/GenericMVAComputer.h"
00032 #include "RecoBTau/JetTagComputer/interface/GenericMVAComputerCache.h"
00033 #include "RecoBTau/JetTagComputer/interface/TagInfoMVACategorySelector.h"
00034 #include "RecoBTau/JetTagMVALearning/interface/JetTagMVATrainer.h"
00035
00036 using namespace reco;
00037 using namespace PhysicsTools;
00038
00039 static const AtomicId kJetPt(TaggingVariableTokens[btau::jetPt]);
00040 static const AtomicId kJetEta(TaggingVariableTokens[btau::jetEta]);
00041
00042 JetTagMVATrainer::JetTagMVATrainer(const edm::ParameterSet ¶ms) :
00043 jetFlavour(params.getParameter<edm::InputTag>("jetFlavourMatching")),
00044 minPt(params.getParameter<double>("minimumTransverseMomentum")),
00045 minEta(params.getParameter<double>("minimumPseudoRapidity")),
00046 maxEta(params.getParameter<double>("maximumPseudoRapidity")),
00047 setupDone(false),
00048 jetTagComputer(params.getParameter<std::string>("jetTagComputer")),
00049 tagInfos(params.getParameter< std::vector<edm::InputTag> >("tagInfos")),
00050 signalFlavours(params.getParameter<std::vector<int> >("signalFlavours")),
00051 ignoreFlavours(params.getParameter<std::vector<int> >("ignoreFlavours"))
00052 {
00053 std::sort(signalFlavours.begin(), signalFlavours.end());
00054 std::sort(ignoreFlavours.begin(), ignoreFlavours.end());
00055
00056 std::vector<std::string> calibrationLabels;
00057 if (params.getParameter<bool>("useCategories")) {
00058 categorySelector.reset(new TagInfoMVACategorySelector(params));
00059
00060 calibrationLabels = categorySelector->getCategoryLabels();
00061 } else {
00062 std::string calibrationRecord =
00063 params.getParameter<std::string>("calibrationRecord");
00064
00065 calibrationLabels.push_back(calibrationRecord);
00066 }
00067
00068 computerCache.reset(new GenericMVAComputerCache(calibrationLabels));
00069 }
00070
00071 JetTagMVATrainer::~JetTagMVATrainer()
00072 {
00073 }
00074
00075 void JetTagMVATrainer::setup(const JetTagComputer &computer)
00076 {
00077 std::vector<std::string> inputLabels(computer.getInputLabels());
00078
00079 if (inputLabels.empty())
00080 inputLabels.push_back("tagInfo");
00081
00082 if (tagInfos.size() != inputLabels.size()) {
00083 std::string message("VInputTag size mismatch - the following "
00084 "taginfo labels are needed:\n");
00085 for(std::vector<std::string>::const_iterator iter =
00086 inputLabels.begin(); iter != inputLabels.end(); ++iter)
00087 message += "\"" + *iter + "\"\n";
00088 throw edm::Exception(edm::errors::Configuration) << message;
00089 }
00090
00091 setupDone = true;
00092 }
00093
00094
00095 namespace {
00096 struct JetCompare :
00097 public std::binary_function<edm::RefToBase<Jet>,
00098 edm::RefToBase<Jet>, bool> {
00099 inline bool operator () (const edm::RefToBase<Jet> &j1,
00100 const edm::RefToBase<Jet> &j2) const
00101 { return j1.key() < j2.key(); }
00102 };
00103 }
00104
00105 struct JetTagMVATrainer::JetInfo {
00106 JetInfo() : flavour(0)
00107 { leptons[0] = leptons[1] = leptons[2] = 0; }
00108
00109 unsigned int flavour;
00110 bool leptons[3];
00111 std::vector<int> tagInfos;
00112 };
00113
// Returns true when `flavour` occurs in `list`.
// `list` must be sorted in ascending order (the constructor sorts the
// configured flavour lists), because the lookup is a binary search.
static bool isFlavour(int flavour, const std::vector<int> &list)
{
	return std::binary_search(list.begin(), list.end(), flavour);
}
00121
00122 bool JetTagMVATrainer::isFlavour(const JetInfo &info,
00123 const std::vector<int> &list)
00124 {
00125 if (::isFlavour(info.flavour, list))
00126 return true;
00127 else if (info.flavour < 4)
00128 return false;
00129
00130 for(unsigned int i = 1; i <= 3; i++)
00131 if (info.leptons[i - 1] &&
00132 ::isFlavour(info.flavour * 10 + i, list))
00133 return true;
00134
00135 return false;
00136 }
00137
00138 bool JetTagMVATrainer::isSignalFlavour(const JetInfo &info) const
00139 {
00140 return isFlavour(info, signalFlavours);
00141 }
00142
00143 bool JetTagMVATrainer::isIgnoreFlavour(const JetInfo &info) const
00144 {
00145 return isFlavour(info, ignoreFlavours);
00146 }
00147
00148 void JetTagMVATrainer::analyze(const edm::Event& event,
00149 const edm::EventSetup& es)
00150 {
00151
00152 edm::ESHandle<Calibration::MVAComputerContainer> calibHandle;
00153 es.get<BTauGenericMVAJetTagComputerRcd>().get("trainer", calibHandle);
00154 const Calibration::MVAComputerContainer *calib = calibHandle.product();
00155
00156
00157 computerCache->update(calib);
00158 if (computerCache->isEmpty())
00159 return;
00160
00161
00162 edm::ESHandle<JetTagComputer> computerHandle;
00163 es.get<JetTagComputerRecord>().get(jetTagComputer, computerHandle);
00164 const GenericMVAJetTagComputer *computer =
00165 dynamic_cast<const GenericMVAJetTagComputer*>(
00166 computerHandle.product());
00167 if (!computer)
00168 throw cms::Exception("InvalidCast")
00169 << "JetTagComputer is not a MVAJetTagComputer "
00170 "in JetTagMVATrainer" << std::endl;
00171
00172 computer->passEventSetup(es);
00173
00174
00175 if (!setupDone)
00176 setup(*computer);
00177
00178
00179 typedef std::map<edm::RefToBase<Jet>, JetInfo, JetCompare> JetInfoMap;
00180 JetInfoMap jetInfos;
00181
00182 std::vector< edm::Handle< edm::View<BaseTagInfo> > >
00183 tagInfoHandles(tagInfos.size());
00184 unsigned int nTagInfos = tagInfos.size();
00185 for(unsigned int i = 0; i < nTagInfos; i++) {
00186 edm::Handle< edm::View<BaseTagInfo> > &tagInfoHandle =
00187 tagInfoHandles[i];
00188 event.getByLabel(tagInfos[i], tagInfoHandle);
00189
00190 int j = 0;
00191 for(edm::View<BaseTagInfo>::const_iterator iter =
00192 tagInfoHandle->begin();
00193 iter != tagInfoHandle->end(); iter++, j++) {
00194
00195 JetInfo &jetInfo = jetInfos[iter->jet()];
00196 if (jetInfo.tagInfos.empty()) {
00197 jetInfo.tagInfos.resize(nTagInfos, -1);
00198 }
00199
00200 jetInfo.tagInfos[i] = j;
00201 }
00202 }
00203
00204
00205 edm::Handle<JetFlavourMatchingCollection> jetFlavourHandle;
00206 event.getByLabel(jetFlavour, jetFlavourHandle);
00207
00208 for(JetFlavourMatchingCollection::const_iterator iter =
00209 jetFlavourHandle->begin();
00210 iter != jetFlavourHandle->end(); iter++) {
00211
00212 JetInfoMap::iterator pos =
00213 jetInfos.find(edm::RefToBase<Jet>(iter->first));
00214 if (pos != jetInfos.end()) {
00215 int flavour = iter->second.getFlavour();
00216 flavour = std::abs(flavour);
00217 if (flavour < 100) {
00218 JetFlavour::Leptons leptons =
00219 iter->second.getLeptons();
00220
00221 pos->second.flavour = flavour;
00222 pos->second.leptons[0] = leptons.electron > 0;
00223 pos->second.leptons[1] = leptons.muon > 0;
00224 pos->second.leptons[2] = leptons.tau > 0;
00225 }
00226 }
00227 }
00228
00229
00230 std::vector<Variable::Value> values;
00231 values.push_back(Variable::Value(MVATrainer::kTargetId, 0));
00232 values.push_back(Variable::Value(kJetPt, 0));
00233 values.push_back(Variable::Value(kJetEta, 0));
00234
00235
00236 for(JetInfoMap::const_iterator iter = jetInfos.begin();
00237 iter != jetInfos.end(); iter++) {
00238 edm::RefToBase<Jet> jet = iter->first;
00239 const JetInfo &info = iter->second;
00240
00241
00242 if (jet->pt() < minPt ||
00243 std::abs(jet->eta()) < minEta ||
00244 std::abs(jet->eta()) > maxEta)
00245 continue;
00246
00247
00248 if (isIgnoreFlavour(info))
00249 continue;
00250
00251
00252 bool target = isSignalFlavour(info);
00253
00254
00255 std::vector<const BaseTagInfo*> tagInfoPtrs(nTagInfos);
00256 for(unsigned int i = 0; i < nTagInfos; i++) {
00257 if (info.tagInfos[i] < 0)
00258 continue;
00259
00260 tagInfoPtrs[i] =
00261 &tagInfoHandles[i]->at(info.tagInfos[i]);
00262 }
00263 JetTagComputer::TagInfoHelper helper(tagInfoPtrs);
00264
00265 TaggingVariableList variables =
00266 computer->taggingVariables(helper);
00267
00268
00269 int index = 0;
00270 if (categorySelector.get()) {
00271 index = categorySelector->findCategory(variables);
00272 if (index < 0)
00273 continue;
00274 }
00275
00276 GenericMVAComputer *mvaComputer =
00277 computerCache->getComputer(index);
00278 if (!mvaComputer)
00279 continue;
00280
00281
00282 values.resize(3 + variables.size());
00283 std::vector<Variable::Value>::iterator insert = values.begin();
00284
00285 (insert++)->setValue(target);
00286 (insert++)->setValue(jet->pt());
00287 (insert++)->setValue(jet->eta());
00288 std::copy(mvaComputer->iterator(variables.begin()),
00289 mvaComputer->iterator(variables.end()), insert);
00290
00291 static_cast<MVAComputer*>(mvaComputer)->eval(values);
00292 }
00293 }