00001 #include <functional>
00002 #include <algorithm>
00003 #include <iterator>
00004 #include <iostream>
00005 #include <vector>
00006 #include <memory>
00007 #include <cmath>
00008 #include <map>
00009
00010 #include <boost/shared_ptr.hpp>
00011
00012 #include <TDirectory.h>
00013 #include <TTree.h>
00014 #include <TFile.h>
00015
00016 #include "FWCore/Utilities/interface/Exception.h"
00017 #include "FWCore/ParameterSet/interface/ParameterSet.h"
00018 #include "FWCore/Utilities/interface/InputTag.h"
00019 #include "FWCore/Framework/interface/Event.h"
00020 #include "FWCore/Framework/interface/EventSetup.h"
00021 #include "FWCore/Framework/interface/ESHandle.h"
00022 #include "FWCore/Framework/interface/EDAnalyzer.h"
00023
00024 #include "SimDataFormats/JetMatching/interface/JetFlavourMatching.h"
00025
00026 #include "DataFormats/Common/interface/View.h"
00027 #include "DataFormats/BTauReco/interface/JetTagInfo.h"
00028 #include "DataFormats/BTauReco/interface/TaggingVariable.h"
00029
00030 #include "CondFormats/PhysicsToolsObjects/interface/MVAComputer.h"
00031
00032 #include "RecoBTau/JetTagComputer/interface/JetTagComputer.h"
00033 #include "RecoBTau/JetTagComputer/interface/JetTagComputerRecord.h"
00034 #include "RecoBTau/JetTagComputer/interface/GenericMVAComputer.h"
00035 #include "RecoBTau/JetTagComputer/interface/GenericMVAJetTagComputer.h"
00036 #include "RecoBTau/JetTagComputer/interface/TagInfoMVACategorySelector.h"
00037
00038 using namespace reco;
00039 using namespace PhysicsTools;
00040
00041 namespace {
00042
00043 class ROOTContextSentinel {
00044 public:
00045 ROOTContextSentinel() : dir(gDirectory), file(gFile) {}
00046 ~ROOTContextSentinel() { gDirectory = dir; gFile = file; }
00047
00048 private:
00049 TDirectory *dir;
00050 TFile *file;
00051 };
00052
00053 }
00054
00055 static const AtomicId kJetPt(TaggingVariableTokens[btau::jetPt]);
00056 static const AtomicId kJetEta(TaggingVariableTokens[btau::jetEta]);
00057
00058 class JetTagMVAExtractor : public edm::EDAnalyzer {
00059 public:
00060 explicit JetTagMVAExtractor(const edm::ParameterSet ¶ms);
00061 ~JetTagMVAExtractor();
00062
00063 virtual void analyze(const edm::Event &event,
00064 const edm::EventSetup &es);
00065
00066 private:
00067 typedef std::vector<Variable::Value> Values;
00068
00069 struct Index {
00070 inline Index(int flavour, int index) :
00071 index(index), flavour(flavour) {}
00072
00073 inline bool operator == (const Index &rhs) const
00074 { return index == rhs.index && flavour == rhs.flavour; }
00075
00076 inline bool operator < (const Index &rhs) const
00077 { return index == rhs.index ? (flavour < rhs.flavour) : (index < rhs.index); }
00078
00079 int index;
00080 int flavour;
00081 };
00082
00083 struct Tree {
00084 Tree(const JetTagMVAExtractor &main, Index index);
00085 ~Tree();
00086
00087 struct Value {
00088 Value() : type(0), multiple(false) {}
00089 Value(char type, bool multiple) : type(type), multiple(multiple) {}
00090
00091 void clear() { sInt = -999; sDouble = -999.0; vInt.clear(); vDouble.clear(); }
00092 void set(double value)
00093 {
00094 if (type == 'I' && multiple)
00095 vInt.push_back((int)std::floor(value + 0.5));
00096 else if (type == 'D' && multiple)
00097 vDouble.push_back(value);
00098 else if (type == 'I' && !multiple)
00099 sInt = (int)std::floor(value + 0.5);
00100 else if (type == 'D' && !multiple)
00101 sDouble = value;
00102 }
00103
00104 char type;
00105 bool multiple;
00106
00107 void *indirect;
00108 Int_t sInt;
00109 Double_t sDouble;
00110 std::vector<int> vInt;
00111 std::vector<double> vDouble;
00112 };
00113
00114 int flavour;
00115 TTree *tree;
00116 std::auto_ptr<TFile> file;
00117 std::map<AtomicId, Value> values;
00118 };
00119
00120 struct Label {
00121 Label() {}
00122 Label(const edm::ParameterSet &pset);
00123 Label(const Label &label) : variables(label.variables), label(label.label) {}
00124
00125 struct Var {
00126 Var(const std::string &name);
00127
00128 AtomicId id;
00129 char type;
00130 bool multiple;
00131 };
00132
00133 std::vector<Var> variables;
00134 std::string label;
00135 };
00136
00137 friend class Tree;
00138
00139 void setup(const JetTagComputer &computer);
00140 void process(Index index, const Values &values);
00141
00142 edm::InputTag jetFlavour;
00143 std::auto_ptr<TagInfoMVACategorySelector> categorySelector;
00144
00145 double minPt;
00146 double minEta;
00147 double maxEta;
00148
00149 bool setupDone;
00150 std::string jetTagComputer;
00151 const GenericMVAComputer mvaComputer;
00152
00153 std::map<std::string, edm::InputTag> tagInfoLabels;
00154 std::vector<edm::InputTag> tagInfos;
00155
00156 std::vector<Label> calibrationLabels;
00157
00158 std::map<Index, boost::shared_ptr<Tree> > treeMap;
00159 };
00160
00161 JetTagMVAExtractor::Tree::Tree(const JetTagMVAExtractor &main, Index index) :
00162 flavour(index.flavour)
00163 {
00164 static const char flavourMap[] = " DUSCBTB G";
00165
00166 if (index.index < 0 || index.index >= (int)main.calibrationLabels.size())
00167 return;
00168
00169 ROOTContextSentinel ctx;
00170
00171 Label label = main.calibrationLabels[index.index];
00172 if (index.flavour > 21) std::cout << index.flavour << std::endl;
00173 std::string flavour = std::string("") + flavourMap[index.flavour];
00174 file.reset(new TFile((label.label + "_" + flavour + ".root").c_str(), "RECREATE"));
00175 file->cd();
00176
00177 tree = new TTree(label.label.c_str(), (label.label + "_" + flavour).c_str());
00178
00179 tree->Branch("flavour", &this->flavour, "flavour/I");
00180
00181 for(std::vector<Label::Var>::const_iterator iter = label.variables.begin();
00182 iter != label.variables.end(); iter++) {
00183 values[iter->id] = Value(iter->type, iter->multiple);
00184 Value &value = values[iter->id];
00185 const char *name = iter->id;
00186
00187 if (iter->type == 'I' && !iter->multiple) {
00188 tree->Branch(name, &value.sInt,
00189 (std::string(name) + "/I").c_str());
00190 } else if (iter->type == 'D' && !iter->multiple) {
00191 tree->Branch(name, &value.sDouble,
00192 (std::string(name) + "/D").c_str());
00193 } else if (iter->type == 'I' && iter->multiple) {
00194 value.indirect = &value.vInt;
00195 tree->Branch(name, "std::vector<int>",
00196 &value.indirect);
00197 } else if (iter->type == 'D' && iter->multiple) {
00198 value.indirect = &value.vDouble;
00199 tree->Branch(name, "std::vector<double>",
00200 &value.indirect);
00201 }
00202 }
00203 }
00204
00205 JetTagMVAExtractor::Tree::~Tree()
00206 {
00207 if (!tree)
00208 return;
00209
00210 ROOTContextSentinel ctx;
00211
00212 file->cd();
00213 tree->Write();
00214 file->Close();
00215 }
00216
00217 JetTagMVAExtractor::Label::Label(const edm::ParameterSet &pset) :
00218 label(pset.getUntrackedParameter<std::string>("label"))
00219 {
00220 std::vector<std::string> vars =
00221 pset.getUntrackedParameter< std::vector<std::string> >("variables");
00222 std::copy(vars.begin(), vars.end(), std::back_inserter(variables));
00223 }
00224
00225 JetTagMVAExtractor::Label::Var::Var(const std::string &name) :
00226 id(name)
00227 {
00228 TaggingVariableName tag = getTaggingVariableName(name);
00229 if (tag == btau::lastTaggingVariable)
00230 throw cms::Exception("UnknownTaggingVariable")
00231 << "Unknown tagging variable " << name << std::endl;
00232
00233 multiple = ((int)tag >= (int)btau::trackMomentum &&
00234 (int)tag <= (int)btau::trackGhostTrackWeight) ||
00235 ((int)tag >= (int)btau::trackP0Par &&
00236 (int)tag <= (int)btau::algoDiscriminator);
00237
00238 type = (tag == btau::jetNTracks ||
00239 tag == btau::vertexCategory ||
00240 tag == btau::jetNSecondaryVertices ||
00241 tag == btau::vertexNTracks ||
00242 tag == btau::trackNTotalHits ||
00243 tag == btau::trackNPixelHits) ? 'I' : 'D';
00244 }
00245
00246 static const Calibration::MVAComputer *dummyCalib()
00247 {
00248 static Calibration::MVAComputer dummy;
00249 static bool init = false;
00250
00251 if (!init)
00252 dummy.inputSet.push_back(Calibration::Variable());
00253
00254 return &dummy;
00255 }
00256
00257 JetTagMVAExtractor::JetTagMVAExtractor(const edm::ParameterSet ¶ms) :
00258 jetFlavour(params.getParameter<edm::InputTag>("jetFlavourMatching")),
00259 minPt(params.getParameter<double>("minimumTransverseMomentum")),
00260 minEta(params.getParameter<double>("minimumPseudoRapidity")),
00261 maxEta(params.getParameter<double>("maximumPseudoRapidity")),
00262 setupDone(false),
00263 jetTagComputer(params.getParameter<std::string>("jetTagComputer")),
00264 mvaComputer(dummyCalib())
00265 {
00266 std::vector<std::string> labels;
00267
00268 if (params.getParameter<bool>("useCategories")) {
00269 categorySelector = std::auto_ptr<TagInfoMVACategorySelector>(
00270 new TagInfoMVACategorySelector(params));
00271
00272 labels = categorySelector->getCategoryLabels();
00273 } else {
00274 std::string calibrationRecord =
00275 params.getParameter<std::string>("calibrationRecord");
00276
00277 labels.push_back(calibrationRecord);
00278 }
00279
00280 std::vector<edm::ParameterSet> variables =
00281 params.getUntrackedParameter< std::vector<edm::ParameterSet> >("variables");
00282
00283 std::map<std::string, Label> labelMap;
00284
00285 for(std::vector<edm::ParameterSet>::const_iterator iter = variables.begin();
00286 iter != variables.end(); iter++) {
00287 Label label(*iter);
00288 if (labelMap.count(label.label))
00289 throw cms::Exception("DuplVariables")
00290 << "Duplicated label for variables "
00291 << label.label << std::endl;
00292 labelMap[label.label] = label;
00293 }
00294
00295 if (labelMap.size() != labels.size())
00296 throw cms::Exception("MismatchVariables")
00297 << "Label variables mismatch." << std::endl;
00298
00299 for(std::vector<std::string>::const_iterator iter = labels.begin();
00300 iter != labels.end(); iter++) {
00301 std::map<std::string, Label>::const_iterator pos =
00302 labelMap.find(*iter);
00303 if (pos == labelMap.end())
00304 throw cms::Exception("MismatchVariables")
00305 << "Variables definition for " << *iter
00306 << " not found." << std::endl;
00307
00308 calibrationLabels.push_back(pos->second);
00309 }
00310
00311 std::vector<std::string> inputTags =
00312 params.getParameterNamesForType<edm::InputTag>();
00313
00314 for(std::vector<std::string>::const_iterator iter = inputTags.begin();
00315 iter != inputTags.end(); iter++)
00316 tagInfoLabels[*iter] =
00317 params.getParameter<edm::InputTag>(*iter);
00318 }
00319
00320 JetTagMVAExtractor::~JetTagMVAExtractor()
00321 {
00322 }
00323
00324 void JetTagMVAExtractor::setup(const JetTagComputer &computer)
00325 {
00326 std::vector<std::string> inputLabels = computer.getInputLabels();
00327 if (inputLabels.empty())
00328 inputLabels.push_back("tagInfo");
00329
00330 for(std::vector<std::string>::const_iterator iter = inputLabels.begin();
00331 iter != inputLabels.end(); iter++) {
00332 std::map<std::string, edm::InputTag>::const_iterator pos =
00333 tagInfoLabels.find(*iter);
00334 if (pos == tagInfoLabels.end())
00335 throw cms::Exception("InputTagMissing")
00336 << "JetTagMVAExtractor is missing a TagInfo "
00337 "InputTag \"" << *iter << "\"" << std::endl;
00338
00339 tagInfos.push_back(pos->second);
00340 }
00341
00342 setupDone = true;
00343 }
00344
00345
00346 namespace {
00347 struct JetCompare :
00348 public std::binary_function<edm::RefToBase<Jet>,
00349 edm::RefToBase<Jet>, bool> {
00350 inline bool operator () (const edm::RefToBase<Jet> &j1,
00351 const edm::RefToBase<Jet> &j2) const
00352 { return j1.key() < j2.key(); }
00353 };
00354
00355 struct JetInfo {
00356 unsigned int flavour;
00357 std::vector<int> tagInfos;
00358 };
00359 }
00360
00361 void JetTagMVAExtractor::analyze(const edm::Event& event,
00362 const edm::EventSetup& es)
00363 {
00364
00365 edm::ESHandle<JetTagComputer> computerHandle;
00366 es.get<JetTagComputerRecord>().get(jetTagComputer, computerHandle);
00367 const GenericMVAJetTagComputer *computer =
00368 dynamic_cast<const GenericMVAJetTagComputer*>(
00369 computerHandle.product());
00370 if (!computer)
00371 throw cms::Exception("InvalidCast")
00372 << "JetTagComputer is not a MVAJetTagComputer "
00373 "in JetTagMVAExtractor" << std::endl;
00374
00375 computer->passEventSetup(es);
00376
00377
00378 if (!setupDone)
00379 setup(*computer);
00380
00381
00382 typedef edm::RefToBase<Jet> JetRef;
00383 typedef std::map<JetRef, JetInfo, JetCompare> JetInfoMap;
00384 JetInfoMap jetInfos;
00385
00386 std::vector< edm::Handle< edm::View<BaseTagInfo> > >
00387 tagInfoHandles(tagInfos.size());
00388 unsigned int nTagInfos = tagInfos.size();
00389 for(unsigned int i = 0; i < nTagInfos; i++) {
00390 edm::Handle< edm::View<BaseTagInfo> > &tagInfoHandle =
00391 tagInfoHandles[i];
00392 event.getByLabel(tagInfos[i], tagInfoHandle);
00393
00394 int j = 0;
00395 for(edm::View<BaseTagInfo>::const_iterator iter =
00396 tagInfoHandle->begin();
00397 iter != tagInfoHandle->end(); iter++, j++) {
00398
00399 JetInfo &jetInfo = jetInfos[iter->jet()];
00400 if (jetInfo.tagInfos.empty()) {
00401 jetInfo.flavour = 0;
00402 jetInfo.tagInfos.resize(nTagInfos, -1);
00403 }
00404
00405 jetInfo.tagInfos[i] = j;
00406 }
00407 }
00408
00409
00410 edm::Handle<JetFlavourMatchingCollection> jetFlavourHandle;
00411 event.getByLabel(jetFlavour, jetFlavourHandle);
00412
00413 for(JetFlavourMatchingCollection::const_iterator iter =
00414 jetFlavourHandle->begin();
00415 iter != jetFlavourHandle->end(); iter++) {
00416
00417 JetInfoMap::iterator pos = jetInfos.find(iter->first);
00418 if (pos != jetInfos.end())
00419 pos->second.flavour =
00420 std::abs(iter->second.getFlavour());
00421 }
00422
00423
00424 std::vector<Variable::Value> values;
00425 values.push_back(Variable::Value(kJetPt, 0));
00426 values.push_back(Variable::Value(kJetEta, 0));
00427
00428
00429 for(JetInfoMap::const_iterator iter = jetInfos.begin();
00430 iter != jetInfos.end(); iter++) {
00431 edm::RefToBase<Jet> jet = iter->first;
00432 const JetInfo &info = iter->second;
00433
00434
00435 if (jet->pt() < minPt ||
00436 std::abs(jet->eta()) < minEta ||
00437 std::abs(jet->eta()) > maxEta)
00438 continue;
00439
00440
00441 if (!info.flavour)
00442 continue;
00443
00444
00445 std::vector<const BaseTagInfo*> tagInfoPtrs(nTagInfos);
00446 for(unsigned int i = 0; i < nTagInfos; i++) {
00447 if (info.tagInfos[i] < 0)
00448 continue;
00449
00450 tagInfoPtrs[i] =
00451 &tagInfoHandles[i]->at(info.tagInfos[i]);
00452 }
00453 JetTagComputer::TagInfoHelper helper(tagInfoPtrs);
00454
00455 TaggingVariableList variables =
00456 computer->taggingVariables(helper);
00457
00458
00459 int index = 0;
00460 if (categorySelector.get()) {
00461 index = categorySelector->findCategory(variables);
00462 if (index < 0)
00463 continue;
00464 }
00465
00466
00467 values.resize(2 + variables.size());
00468 std::vector<Variable::Value>::iterator insert = values.begin();
00469
00470 (insert++)->setValue(jet->pt());
00471 (insert++)->setValue(jet->eta());
00472 std::copy(mvaComputer.iterator(variables.begin()),
00473 mvaComputer.iterator(variables.end()), insert);
00474
00475 process(Index(info.flavour, index), values);
00476 }
00477 }
00478
00479 void JetTagMVAExtractor::process(Index index, const Values &values)
00480 {
00481 if (index.flavour == 7)
00482 index.flavour = 5;
00483
00484 std::map<Index, boost::shared_ptr<Tree> >::iterator pos = treeMap.find(index);
00485 Tree *tree;
00486
00487 if (pos == treeMap.end())
00488 tree = treeMap.insert(std::make_pair(index, boost::shared_ptr<Tree>(new Tree(*this, index)))).first->second.get();
00489 else
00490 tree = pos->second.get();
00491
00492 if (!tree->tree)
00493 return;
00494
00495 for(std::map<AtomicId, Tree::Value>::iterator iter = tree->values.begin();
00496 iter != tree->values.end(); iter++)
00497 iter->second.clear();
00498
00499 for(Values::const_iterator iter = values.begin();
00500 iter != values.end(); iter++) {
00501 std::map<AtomicId, Tree::Value>::iterator pos = tree->values.find(iter->getName());
00502 if (pos == tree->values.end())
00503 throw cms::Exception("VarNotFound")
00504 << "Variable " << (const char*)iter->getName()
00505 << " not found." << std::endl;
00506
00507 pos->second.set(iter->getValue());
00508 }
00509
00510 tree->tree->Fill();
00511 }
00512
00513 #include "FWCore/Framework/interface/MakerMacros.h"
00514
00515
00516 DEFINE_FWK_MODULE(JetTagMVAExtractor);