#include <algorithm>
#include <cmath>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#include <boost/shared_ptr.hpp>

#include <TRandom.h>
#include <TString.h>
#include <TFile.h>
#include <TTree.h>
#include <TBranch.h>
#include <TLeaf.h>
#include <TList.h>
#include <TKey.h>

#include "FWCore/Framework/interface/MakerMacros.h"
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Framework/interface/EventSetup.h"
#include "FWCore/Framework/interface/Run.h"
#include "FWCore/Framework/interface/ESHandle.h"
#include "FWCore/Framework/interface/EDAnalyzer.h"

#include "SimDataFormats/JetMatching/interface/JetFlavourMatching.h"

#include "DataFormats/Common/interface/Ref.h"
#include "DataFormats/BTauReco/interface/JetTagInfo.h"
#include "DataFormats/BTauReco/interface/TaggingVariable.h"

#include "CondFormats/PhysicsToolsObjects/interface/MVAComputer.h"
#include "CondFormats/DataRecord/interface/BTauGenericMVAJetTagComputerRcd.h"

#include "PhysicsTools/MVATrainer/interface/MVATrainer.h"

#include "RecoBTau/JetTagComputer/interface/JetTagComputerRecord.h"
#include "RecoBTau/JetTagComputer/interface/GenericMVAComputer.h"
#include "RecoBTau/JetTagComputer/interface/GenericMVAComputerCache.h"
#include "RecoBTau/JetTagComputer/interface/TagInfoMVACategorySelector.h"

#include "EventProgress.h"

using namespace reco;
using namespace PhysicsTools;
00050
/*
 * Correction factor as a function of jet pt and eta.
 *
 * A Fit is either
 *  - fixed: it always evaluates to the constant it was constructed with, or
 *  - parametrised: built from a text file holding 7 x 6 whitespace-separated
 *    coefficients (row-major).  Row i is a degree-5 polynomial c_i(y) in a
 *    rescaled pt variable y.  The fit value is sum_i c_i(y) * T_{2i}(x), where
 *    T_n are Chebyshev polynomials of the first kind and x is the rescaled eta.
 */
class Fit {
public:
	/// Fixed fit evaluating to 0 (converts to false).
	Fit() : isFixed(true), fixedValue(0.0) { clearParams(); }

	/// Fixed fit evaluating to `value` for every jet.
	explicit Fit(double value) : isFixed(true), fixedValue(value)
	{ clearParams(); }

	/// Parametrised fit read from `fileName`.
	/// @throws std::runtime_error if the file cannot be opened or does not
	///         contain 7 x 6 readable numbers (previously the coefficients
	///         were silently left uninitialized).
	explicit Fit(const std::string &fileName) :
		isFixed(false), fixedValue(0.0)
	{
		clearParams();

		std::ifstream f(fileName.c_str());
		if (!f)
			throw std::runtime_error("Fit: cannot open parameter file \"" +
			                         fileName + "\"");

		for(int i = 0; i < kEtaTerms; i++)
			for(int j = 0; j < kPtTerms; j++)
				if (!(f >> params[i][j]))
					throw std::runtime_error("Fit: parameter file \"" +
					                         fileName +
					                         "\" does not contain 7 x 6 numbers");
	}

	/// True for parametrised fits and for fixed fits with a positive value.
	operator bool() const { return !isFixed || fixedValue > 0.0; }

	/*
	 * Evaluate the fit at (pt, eta).
	 *
	 * A fixed fit returns its constant unchanged, even when isRev is set.
	 * A parametrised fit returns f(pt, eta), or 1 / f(pt, eta) when isRev.
	 */
	double operator () (double pt, double eta, bool isRev = false) const
	{
		if (isFixed)
			return fixedValue;

		// eta / 2.5 clamped to [-1, 1]
		double x = std::min(std::max(-1.0, eta / 2.5), 1.0);
		// (log(pt + 50) - ln(60)) * 40 / 2.1 + 0.5, clamped to [0, 36]
		double y = std::min(std::max(0.0, (std::log(pt + 50.0) - 4.0943445622221004) * 40.0 / 2.1 + 0.5), 36.0);

		// Horner evaluation of every row's polynomial in y
		double facs[kEtaTerms];
		for(int i = 0; i < kEtaTerms; i++) {
			const double *v = params[i];
			facs[i] = v[0] + y * (v[1] + y * (v[2] + y * (v[3] + y * (v[4] + y * v[5]))));
		}

		// even powers of x: xs[k] = x^(2k + 2)
		double xs[6];
		xs[0] = x * x;
		xs[1] = xs[0] * xs[0];
		xs[2] = xs[1] * xs[0];
		xs[3] = xs[1] * xs[1];
		xs[4] = xs[2] * xs[1];
		xs[5] = xs[2] * xs[2];

		// facs[k] multiplies the Chebyshev polynomial T_{2k}(x)
		double val =
			facs[0] +
			facs[1] * (2 * xs[0] - 1) +
			facs[2] * (8 * xs[1] - 8 * xs[0] + 1) +
			facs[3] * (32 * xs[2] - 48 * xs[1] + 18 * xs[0] - 1) +
			facs[4] * (128 * xs[3] - 256 * xs[2] + 160 * xs[1] - 32 * xs[0] + 1) +
			facs[5] * (512 * xs[4] - 1280 * xs[3] + 1120 * xs[2] - 400 * xs[1] + 50 * xs[0] - 1) +
			facs[6] * (2048 * xs[5] - 6144 * xs[4] + 6912 * xs[3] - 3584 * xs[2] + 840 * xs[1] - 72 * xs[0] + 1);
		if (isRev)
			return 1.0 / val;
		else
			return val;
	}

private:
	// number of Chebyshev terms in x (rows) and polynomial coefficients in y
	enum { kEtaTerms = 7, kPtTerms = 6 };

	// zero all coefficients so no code path ever reads indeterminate values
	void clearParams()
	{ std::fill(&params[0][0], &params[0][0] + kEtaTerms * kPtTerms, 0.0); }

	bool isFixed;
	double fixedValue;
	double params[kEtaTerms][kPtTerms];
};
00106
00107 class Var {
00108 public:
00109 Var(char type, TTree *tree, const char *name) :
00110 type(type), var(getTaggingVariableName(name))
00111 {
00112 switch(type) {
00113 case 'D':
00114 tree->SetBranchAddress(name, &D);
00115 break;
00116 case 'I':
00117 tree->SetBranchAddress(name, &I);
00118 break;
00119 case 'd':
00120 indirect = &d;
00121 tree->SetBranchAddress(name, &indirect);
00122 break;
00123 case 'i':
00124 indirect = &i;
00125 tree->SetBranchAddress(name, &indirect);
00126 break;
00127 }
00128 }
00129
00130 void fill(TaggingVariableList &list)
00131 {
00132 switch(type) {
00133 case 'D':
00134 list.insert(var, D, true);
00135 break;
00136 case 'I':
00137 list.insert(var, I, true);
00138 break;
00139 case 'd':
00140 for(std::vector<double>::const_iterator p = d.begin();
00141 p != d.end(); p++)
00142 list.insert(var, *p, true);
00143 break;
00144 case 'i':
00145 for(std::vector<int>::const_iterator p = i.begin();
00146 p != i.end(); p++)
00147 list.insert(var, *p, true);
00148 break;
00149 }
00150 }
00151
00152 static bool order(const boost::shared_ptr<Var> &a,
00153 const boost::shared_ptr<Var> &b)
00154 { return a->var < b->var; }
00155
00156 private:
00157 char type;
00158 TaggingVariableName var;
00159 double D;
00160 int I;
00161 std::vector<double> d;
00162 std::vector<int> i;
00163 void *indirect;
00164 };
00165
00166 class JetTagMVATreeTrainer : public edm::EDAnalyzer {
00167 public:
00168 explicit JetTagMVATreeTrainer(const edm::ParameterSet ¶ms);
00169 ~JetTagMVATreeTrainer();
00170
00171 virtual void beginRun(const edm::Run &run,
00172 const edm::EventSetup &es);
00173
00174 virtual void analyze(const edm::Event &event,
00175 const edm::EventSetup &es);
00176
00177 protected:
00178 bool isSignalFlavour(int flavour) const;
00179 bool isIgnoreFlavour(int flavour) const;
00180
00181 std::auto_ptr<TagInfoMVACategorySelector> categorySelector;
00182 std::auto_ptr<GenericMVAComputerCache> computerCache;
00183
00184 double minPt;
00185 double minEta;
00186 double maxEta;
00187 double factor;
00188 double bound;
00189 double signalFactor;
00190
00191 private:
00192 std::vector<int> signalFlavours;
00193 std::vector<int> ignoreFlavours;
00194 Fit weights;
00195 std::vector<Fit> bias;
00196 double limiter;
00197 int maxEvents;
00198 TRandom rand;
00199
00200 std::vector<std::string> fileNames;
00201 };
00202
00203 JetTagMVATreeTrainer::JetTagMVATreeTrainer(const edm::ParameterSet ¶ms) :
00204 minPt(params.getParameter<double>("minimumTransverseMomentum")),
00205 minEta(params.getParameter<double>("minimumPseudoRapidity")),
00206 maxEta(params.getParameter<double>("maximumPseudoRapidity")),
00207 factor(params.getParameter<double>("factor")),
00208 bound(params.getParameter<double>("bound")),
00209 signalFactor(params.getUntrackedParameter<double>("signalFactor", 1.0)),
00210 signalFlavours(params.getParameter<std::vector<int> >("signalFlavours")),
00211 ignoreFlavours(params.getParameter<std::vector<int> >("ignoreFlavours")),
00212 limiter(params.getUntrackedParameter<double>("weightThreshold", 0.0)),
00213 maxEvents(params.getUntrackedParameter<int>("maxEvents", -1)),
00214 fileNames(params.getParameter<std::vector<std::string> >("fileNames"))
00215 {
00216 std::sort(signalFlavours.begin(), signalFlavours.end());
00217 std::sort(ignoreFlavours.begin(), ignoreFlavours.end());
00218
00219 std::vector<std::string> calibrationLabels;
00220 if (params.getParameter<bool>("useCategories")) {
00221 categorySelector = std::auto_ptr<TagInfoMVACategorySelector>(
00222 new TagInfoMVACategorySelector(params));
00223
00224 calibrationLabels = categorySelector->getCategoryLabels();
00225 } else {
00226 std::string calibrationRecord =
00227 params.getParameter<std::string>("calibrationRecord");
00228
00229 calibrationLabels.push_back(calibrationRecord);
00230 }
00231
00232 computerCache = std::auto_ptr<GenericMVAComputerCache>(
00233 new GenericMVAComputerCache(calibrationLabels));
00234
00235 weights = Fit(params.getParameter<std::string>("weightFile"));
00236
00237 std::vector<std::string> biasFiles = params.getParameter<std::vector<std::string> >("biasFiles");
00238 for(std::vector<std::string>::const_iterator iter = biasFiles.begin();
00239 iter != biasFiles.end(); iter++) {
00240 if (*iter == "*")
00241 bias.push_back(Fit(1.0));
00242 else if (*iter == "-")
00243 bias.push_back(Fit(0.0));
00244 else
00245 bias.push_back(Fit(*iter));
00246 }
00247 }
00248
00249 JetTagMVATreeTrainer::~JetTagMVATreeTrainer()
00250 {
00251 }
00252
00253 bool JetTagMVATreeTrainer::isSignalFlavour(int flavour) const
00254 {
00255 std::vector<int>::const_iterator pos =
00256 std::lower_bound(signalFlavours.begin(), signalFlavours.end(),
00257 flavour);
00258
00259 return pos != signalFlavours.end() && *pos == flavour;
00260 }
00261
00262 bool JetTagMVATreeTrainer::isIgnoreFlavour(int flavour) const
00263 {
00264 std::vector<int>::const_iterator pos =
00265 std::lower_bound(ignoreFlavours.begin(), ignoreFlavours.end(),
00266 flavour);
00267
00268 return pos != ignoreFlavours.end() && *pos == flavour;
00269 }
00270
00271 void JetTagMVATreeTrainer::beginRun(const edm::Run& run,
00272 const edm::EventSetup& es)
00273 {
00274 rand.SetSeed(65539);
00275 }
00276
00277 void JetTagMVATreeTrainer::analyze(const edm::Event& event,
00278 const edm::EventSetup& es)
00279 {
00280
00281 edm::ESHandle<Calibration::MVAComputerContainer> calibHandle;
00282 es.get<BTauGenericMVAJetTagComputerRcd>().get("trainer", calibHandle);
00283 const Calibration::MVAComputerContainer *calib = calibHandle.product();
00284
00285
00286 computerCache->update(calib);
00287 if (computerCache->isEmpty())
00288 return;
00289
00290
00291 std::vector<Variable::Value> values;
00292 values.push_back(Variable::Value(MVATrainer::kTargetId, 0));
00293 values.push_back(Variable::Value(MVATrainer::kWeightId, 0));
00294
00295 int nEvents = 0;
00296 for(std::vector<std::string>::const_iterator fName = fileNames.begin();
00297 fName != fileNames.end(); fName++) {
00298 if (maxEvents >= 0 && nEvents >= maxEvents)
00299 break;
00300
00301 std::auto_ptr<TFile> file(TFile::Open(fName->c_str()));
00302 if (!file.get())
00303 continue;
00304 std::cout << "Opened " << *fName << std::endl;
00305
00306 TIter next(file->GetListOfKeys());
00307 TObject *obj;
00308 while((obj = next())) {
00309 if (maxEvents >= 0 && nEvents >= maxEvents)
00310 break;
00311
00312 TTree *tree = dynamic_cast<TTree*>(file->Get(((TKey*)obj)->GetName()));
00313 if (!tree)
00314 continue;
00315 std::cout << "Tree " << tree->GetName() << std::endl;
00316
00317 int flavour;
00318 tree->SetBranchAddress("flavour", &flavour);
00319
00320 std::vector< boost::shared_ptr<Var> > vars;
00321
00322 TIter branchIter(tree->GetListOfBranches());
00323 while((obj = branchIter())) {
00324 TBranch *branch = dynamic_cast<TBranch*>(obj);
00325 if (!branch)
00326 continue;
00327
00328 TString name = branch->GetName();
00329 TLeaf *leaf = dynamic_cast<TLeaf*>(
00330 branch->GetLeaf(name));
00331 if (!leaf)
00332 continue;
00333
00334 TString typeName = leaf->GetTypeName();
00335 char typeId;
00336 if (typeName == "Double_t")
00337 typeId = 'D';
00338 else if (typeName == "Int_t")
00339 typeId = 'I';
00340 else if (typeName == "vector<double>")
00341 typeId = 'd';
00342 else if (typeName == "vector<int>")
00343 typeId = 'i';
00344 else
00345 continue;
00346
00347 if (getTaggingVariableName((const char *)name) ==
00348 btau::lastTaggingVariable)
00349 continue;
00350 vars.push_back(boost::shared_ptr<Var>(
00351 new Var(typeId, tree, name)));
00352 }
00353 std::sort(vars.begin(), vars.end(), &Var::order);
00354
00355 Long64_t entries = tree->GetEntries();
00356 std::cout << "Entries " << entries << std::endl;
00357 EventProgress progress(entries);
00358 for(Long64_t entry = 0; entry < entries; entry++) {
00359 if (maxEvents >= 0 && nEvents >= maxEvents)
00360 break;
00361
00362 progress.update(entry);
00363 tree->GetEntry(entry);
00364
00365 TaggingVariableList variables;
00366 for(std::vector< boost::shared_ptr<Var> >::const_iterator iter = vars.begin();
00367 iter != vars.end(); iter++)
00368 (*iter)->fill(variables);
00369 variables.finalize();
00370
00371 double jetPt = variables[btau::jetPt];
00372 double jetEta = variables[btau::jetEta];
00373
00374
00375 if (jetPt < minPt ||
00376 std::abs(jetEta) < minEta ||
00377 std::abs(jetEta) > maxEta)
00378 continue;
00379
00380
00381 if (isIgnoreFlavour(flavour))
00382 continue;
00383
00384
00385 bool target = isSignalFlavour(flavour);
00386
00387
00388 int index = 0;
00389 if (categorySelector.get()) {
00390 index = categorySelector->findCategory(variables);
00391 if (index < 0)
00392 continue;
00393 }
00394
00395 GenericMVAComputer *mvaComputer =
00396 computerCache->getComputer(index);
00397 if (!mvaComputer)
00398 continue;
00399
00400 int idx = 0;
00401 if (flavour == 4)
00402 idx = 1;
00403 else if (flavour == 5 || flavour == 7)
00404 idx = 2;
00405 double pBias[3];
00406 for(int i = 0; i < 3; i++)
00407 pBias[i] = bias[i](jetPt, jetEta, i < 2);
00408 double weight;
00409 if (bias[0] && bias[1])
00410 weight = (idx == 0) ? 0.75 :
00411 (idx == 1) ? 0.25 : 1.0;
00412 else
00413 weight = 1.0;
00414 weight *= jetPt / 50.0 + 1.0;
00415 weight /= 1.0 + std::exp((jetPt - 600.0) / 150.0);
00416 weight *= (1.0 - 0.01 * std::exp(0.5 * jetEta * jetEta));
00417 weight /= weights(jetPt, jetEta);
00418 weight *= pBias[0] + pBias[1] + pBias[2];
00419 weight /= pBias[idx];
00420
00421 weight *= factor;
00422 if (weight > bound)
00423 weight = bound;
00424
00425 if (idx == 2)
00426 weight *= signalFactor;
00427
00428 if (weight < limiter) {
00429 if (rand.Uniform(limiter) > weight)
00430 continue;
00431 weight = limiter;
00432 }
00433
00434
00435 values.resize(2 + variables.size());
00436
00437 std::vector<Variable::Value>::iterator insert = values.begin();
00438 (insert++)->setValue(target);
00439 (insert++)->setValue(weight);
00440
00441 std::copy(mvaComputer->iterator(variables.begin()),
00442 mvaComputer->iterator(variables.end()), insert);
00443
00444 static_cast<MVAComputer*>(mvaComputer)->eval(values);
00445
00446 nEvents++;
00447 }
00448 }
00449 }
00450 }
00451
00452
// Register JetTagMVATreeTrainer as a framework module (macro from
// FWCore/Framework/interface/MakerMacros.h) so configurations can load it by name.
DEFINE_FWK_MODULE(JetTagMVATreeTrainer);