CMS 3D CMS Logo

DefaultTrackMVAClassifier.cc
Go to the documentation of this file.
2 
6 
#include <limits>
#include <stdexcept>

#include "getBestVertex.h"

#include "TFile.h"
14 
15 namespace {
16 
17 template<bool PROMPT>
18 struct mva {
19  mva(const edm::ParameterSet &cfg):
20  forestLabel_ ( cfg.getParameter<std::string>("GBRForestLabel") ),
21  dbFileName_ ( cfg.getParameter<std::string>("GBRForestFileName") ),
22  useForestFromDB_( (!forestLabel_.empty()) & dbFileName_.empty())
23  {}
24 
25  void beginStream() {
26  if(!dbFileName_.empty()){
27  TFile gbrfile(dbFileName_.c_str());
28  forestFromFile_.reset((GBRForest*)gbrfile.Get(forestLabel_.c_str()));
29  }
30  }
31 
32  void initEvent(const edm::EventSetup& es) {
33  forest_ = forestFromFile_.get();
34  if(useForestFromDB_){
35  edm::ESHandle<GBRForest> forestHandle;
36  es.get<GBRWrapperRcd>().get(forestLabel_,forestHandle);
37  forest_ = forestHandle.product();
38  }
39  }
40 
// Evaluates the GBR forest for one track and returns its classifier output.
// Inputs: 12 track-quality variables. For PROMPT = true, four more: |d0| and
// |dz| relative to the beam spot and to the vertex picked by getBestVertex().
// Requires forest_ to be non-null; it is set in initEvent().
// NOTE(review): original source lines 49-50 and 70-71 are missing from this
// view. They hold the rest of the tmva_nlayers3D_ expression and the
// definitions of tmva_nlayerslost_, lostIn and lostOut used below; those
// parts are not documented here.
41  float operator()(reco::Track const & trk,
42  reco::BeamSpot const & beamSpot,
43  reco::VertexCollection const & vertices) const {
44 
45  auto tmva_pt_ = trk.pt();
46  auto tmva_ndof_ = trk.ndof();
47  auto tmva_nlayers_ = trk.hitPattern().trackerLayersWithMeasurement();
48  auto tmva_nlayers3D_ = trk.hitPattern().pixelLayersWithMeasurement()
51  float chi2n = trk.normalizedChi2();
52  float chi2n_no1Dmod = chi2n;
53 
// Count the track's rec hits whose dimension is 1.
54  int count1dhits = 0;
55  for (auto ith =trk.recHitsBegin(); ith!=trk.recHitsEnd(); ++ith) {
56  const auto & hit = *(*ith);
57  if (hit.dimension()==1) ++count1dhits;
58  }
59 
// When 1D hits are present, chi2n becomes (chi2 + n1D) / (ndof + n1D).
// The unmodified chi2/ndof is still passed on as chi2n_no1Dmod.
60  if (count1dhits > 0) {
61  float chi2 = trk.chi2();
62  float ndof = trk.ndof();
63  chi2n = (chi2+count1dhits)/float(ndof+count1dhits);
64  }
65  auto tmva_chi2n_ = chi2n;
66  auto tmva_chi2n_no1dmod_ = chi2n_no1Dmod;
67  auto tmva_eta_ = trk.eta();
// Relative pT error; pT is floored at 1e-6 to avoid division by zero.
68  auto tmva_relpterr_ = float(trk.ptError())/std::max(float(trk.pt()),0.000001f);
69  auto tmva_nhits_ = trk.numberOfValidHits();
72  auto tmva_minlost_ = std::min(lostIn,lostOut);
// Fraction of lost hits among valid + lost hits.
// NOTE(review): with no valid and no lost hits this is 0/0 (NaN); presumably
// impossible for fitted tracks -- confirm.
73  auto tmva_lostmidfrac_ = static_cast<float>(trk.numberOfLostHits()) / static_cast<float>(trk.numberOfValidHits() + trk.numberOfLostHits());
74 
// Forest input vector: 16 entries for PROMPT, 12 otherwise.
// NOTE(review): the index order presumably has to match the input order used
// when the forest was trained -- confirm before reordering.
75  float gbrVals_[PROMPT ? 16 : 12];
76  gbrVals_[0] = tmva_pt_;
77  gbrVals_[1] = tmva_lostmidfrac_;
78  gbrVals_[2] = tmva_minlost_;
79  gbrVals_[3] = tmva_nhits_;
80  gbrVals_[4] = tmva_relpterr_;
81  gbrVals_[5] = tmva_eta_;
82  gbrVals_[6] = tmva_chi2n_no1dmod_;
83  gbrVals_[7] = tmva_chi2n_;
84  gbrVals_[8] = tmva_nlayerslost_;
85  gbrVals_[9] = tmva_nlayers3D_;
86  gbrVals_[10] = tmva_nlayers_;
87  gbrVals_[11] = tmva_ndof_;
88 
// PROMPT is a template parameter, so this test is a compile-time constant.
// For PROMPT = false, gbrVals_ has only 12 entries and this branch (which
// writes entries 12-15) never executes.
// NOTE(review): 'if constexpr' would remove it from that instantiation entirely.
89  if (PROMPT) {
90  auto tmva_absd0_ = std::abs(trk.dxy(beamSpot.position()));
91  auto tmva_absdz_ = std::abs(trk.dz(beamSpot.position()));
// Vertex chosen by getBestVertex() from getBestVertex.h (default minNtracks = 2).
92  Point bestVertex = getBestVertex(trk,vertices);
93  auto tmva_absd0PV_ = std::abs(trk.dxy(bestVertex));
94  auto tmva_absdzPV_ = std::abs(trk.dz(bestVertex));
95 
96  gbrVals_[12] = tmva_absd0PV_;
97  gbrVals_[13] = tmva_absdzPV_;
98  gbrVals_[14] = tmva_absdz_;
99  gbrVals_[15] = tmva_absd0_;
100  }
101 
102 
103 
// forest_ is dereferenced without a check. It is null when beginStream() loaded
// nothing and the EventSetup path (useForestFromDB_) is disabled.
104  return forest_->GetClassifier(gbrVals_);
105 
106  }
107 
108  static const char * name();
109 
110  static void fillDescriptions(edm::ParameterSetDescription & desc) {
111  desc.add<std::string>("GBRForestLabel",std::string());
112  desc.add<std::string>("GBRForestFileName",std::string());
113  }
114 
115  std::unique_ptr<GBRForest> forestFromFile_;
116  const GBRForest *forest_ = nullptr; // owned by somebody else
117  const std::string forestLabel_;
118  const std::string dbFileName_;
119  const bool useForestFromDB_;
120 };
121 
122  using TrackMVAClassifierDetached = TrackMVAClassifier<mva<false>>;
123  using TrackMVAClassifierPrompt = TrackMVAClassifier<mva<true>>;
124  template<>
125  const char * mva<false>::name() { return "TrackMVAClassifierDetached";}
126  template<>
127  const char * mva<true>::name() { return "TrackMVAClassifierPrompt";}
128 
129 }
130 
133 
// Expose both classifier flavours (detached and prompt) as framework modules.
134 DEFINE_FWK_MODULE(TrackMVAClassifierDetached);
135 DEFINE_FWK_MODULE(TrackMVAClassifierPrompt);
double normalizedChi2() const
chi-squared divided by n.d.o.f. (or chi-squared * 1e6 if n.d.o.f. is zero)
Definition: TrackBase.h:594
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:17
unsigned short numberOfLostHits() const
number of cases where track crossed a layer without getting a hit.
Definition: TrackBase.h:895
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
int pixelLayersWithMeasurement() const
Definition: HitPattern.cc:538
int trackerLayersWithMeasurement() const
Definition: HitPattern.cc:557
double eta() const
pseudorapidity of momentum vector
Definition: TrackBase.h:684
int numberOfValidStripLayersWithMonoAndStereo(uint16_t stripdet, uint16_t layer) const
Definition: HitPattern.cc:379
double chi2() const
chi-squared of the fit
Definition: TrackBase.h:582
double ndof() const
number of degrees of freedom of the fit
Definition: TrackBase.h:588
double pt() const
track transverse momentum
Definition: TrackBase.h:654
double ptError() const
error on Pt (set to 1000 TeV if charge==0 for safety)
Definition: TrackBase.h:808
math::XYZPoint Point
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
double f[11][100]
unsigned short numberOfValidHits() const
number of valid hits found
Definition: TrackBase.h:889
T min(T a, T b)
Definition: MathUtil.h:58
trackingRecHit_iterator recHitsBegin() const
Iterator to first hit on the track.
Definition: Track.h:106
ParameterDescriptionBase * add(U const &iLabel, T const &value)
double dz() const
dz parameter (= dsz/cos(lambda)). This is the track z0 w.r.t (0,0,0) only if the refPoint is close to...
Definition: TrackBase.h:642
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
const HitPattern & hitPattern() const
Access the hit pattern, indicating in which Tracker layers the track has hits.
Definition: TrackBase.h:479
int numberOfLostHits(HitCategory category) const
Definition: HitPattern.h:982
int trackerLayersWithoutMeasurement(HitCategory category) const
Definition: HitPattern.cc:574
T get() const
Definition: EventSetup.h:68
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:9
const Point & position() const
position
Definition: BeamSpot.h:62
double dxy() const
dxy parameter. (This is the transverse impact parameter w.r.t. to (0,0,0) ONLY if refPoint is close t...
Definition: TrackBase.h:624
T const * product() const
Definition: ESHandle.h:84
trackingRecHit_iterator recHitsEnd() const
Iterator to last hit on the track.
Definition: Track.h:111