CMS 3D CMS Logo

DefaultTrackMVAClassifier.cc
Go to the documentation of this file.
2 
6 
10 
#include <limits>
#include <stdexcept>
#include <string>

#include "TFile.h"

15 namespace {
16 
17  template <bool PROMPT>
18  struct mva {
20  : forestLabel_(cfg.getParameter<std::string>("GBRForestLabel")),
21  dbFileName_(cfg.getParameter<std::string>("GBRForestFileName")),
22  useForestFromDB_((!forestLabel_.empty()) && dbFileName_.empty()) {
23  if (useForestFromDB_) {
24  forestToken_ = iC.esConsumes(edm::ESInputTag("", forestLabel_));
25  }
26  }
27 
28  void beginStream() {
29  if (!dbFileName_.empty()) {
30  TFile gbrfile(dbFileName_.c_str());
31  forestFromFile_.reset((GBRForest *)gbrfile.Get(forestLabel_.c_str()));
32  }
33  }
34 
35  void initEvent(const edm::EventSetup &es) {
36  forest_ = forestFromFile_.get();
37  if (useForestFromDB_) {
38  forest_ = &es.getData(forestToken_);
39  }
40  }
41 
42  float operator()(reco::Track const &trk,
43  reco::BeamSpot const &beamSpot,
44  reco::VertexCollection const &vertices) const {
45  auto tmva_pt_ = trk.pt();
46  auto tmva_ndof_ = trk.ndof();
47  auto tmva_nlayers_ = trk.hitPattern().trackerLayersWithMeasurement();
48  auto tmva_nlayers3D_ =
51  float chi2n = trk.normalizedChi2();
52  float chi2n_no1Dmod = chi2n;
53 
54  int count1dhits = 0;
55  for (auto ith = trk.recHitsBegin(); ith != trk.recHitsEnd(); ++ith) {
56  const auto &hit = *(*ith);
57  if (hit.dimension() == 1)
58  ++count1dhits;
59  }
60 
61  if (count1dhits > 0) {
62  float chi2 = trk.chi2();
63  float ndof = trk.ndof();
64  chi2n = (chi2 + count1dhits) / float(ndof + count1dhits);
65  }
66  auto tmva_chi2n_ = chi2n;
67  auto tmva_chi2n_no1dmod_ = chi2n_no1Dmod;
68  auto tmva_eta_ = trk.eta();
69  auto tmva_relpterr_ = float(trk.ptError()) / std::max(float(trk.pt()), 0.000001f);
70  auto tmva_nhits_ = trk.numberOfValidHits();
73  auto tmva_minlost_ = std::min(lostIn, lostOut);
74  auto tmva_lostmidfrac_ = static_cast<float>(trk.numberOfLostHits()) /
75  static_cast<float>(trk.numberOfValidHits() + trk.numberOfLostHits());
76 
77  float gbrVals_[PROMPT ? 16 : 12];
78  gbrVals_[0] = tmva_pt_;
79  gbrVals_[1] = tmva_lostmidfrac_;
80  gbrVals_[2] = tmva_minlost_;
81  gbrVals_[3] = tmva_nhits_;
82  gbrVals_[4] = tmva_relpterr_;
83  gbrVals_[5] = tmva_eta_;
84  gbrVals_[6] = tmva_chi2n_no1dmod_;
85  gbrVals_[7] = tmva_chi2n_;
86  gbrVals_[8] = tmva_nlayerslost_;
87  gbrVals_[9] = tmva_nlayers3D_;
88  gbrVals_[10] = tmva_nlayers_;
89  gbrVals_[11] = tmva_ndof_;
90 
91  if (PROMPT) {
92  auto tmva_absd0_ = std::abs(trk.dxy(beamSpot.position()));
93  auto tmva_absdz_ = std::abs(trk.dz(beamSpot.position()));
94  Point bestVertex = getBestVertex(trk, vertices);
95  auto tmva_absd0PV_ = std::abs(trk.dxy(bestVertex));
96  auto tmva_absdzPV_ = std::abs(trk.dz(bestVertex));
97 
98  gbrVals_[12] = tmva_absd0PV_;
99  gbrVals_[13] = tmva_absdzPV_;
100  gbrVals_[14] = tmva_absdz_;
101  gbrVals_[15] = tmva_absd0_;
102  }
103 
104  return forest_->GetClassifier(gbrVals_);
105  }
106 
107  static const char *name();
108 
110  desc.add<std::string>("GBRForestLabel", std::string());
111  desc.add<std::string>("GBRForestFileName", std::string());
112  }
113 
114  std::unique_ptr<GBRForest> forestFromFile_;
115  const GBRForest *forest_ = nullptr; // owned by somebody else
116  const std::string forestLabel_;
117  const std::string dbFileName_;
118  const bool useForestFromDB_;
120  };
121 
122  using TrackMVAClassifierDetached = TrackMVAClassifier<mva<false>>;
123  using TrackMVAClassifierPrompt = TrackMVAClassifier<mva<true>>;
124  template <>
125  const char *mva<false>::name() {
126  return "TrackMVAClassifierDetached";
127  }
128  template <>
129  const char *mva<true>::name() {
130  return "TrackMVAClassifierPrompt";
131  }
132 
133 } // namespace
134 
137 
// Expand DEFINE_FWK_MODULE (MakerMacros.h) for both classifier flavours.
// NOTE(review): presumably this registers them as framework plugins under
// their type names — confirm against MakerMacros.h.
138 DEFINE_FWK_MODULE(TrackMVAClassifierDetached);
139 DEFINE_FWK_MODULE(TrackMVAClassifierPrompt);
int numberOfLostHits(HitCategory category) const
Definition: HitPattern.h:891
T const & getData(const ESGetToken< T, R > &iToken) const noexcept(false)
Definition: EventSetup.h:119
double ptError() const
error on Pt (set to 1000 TeV if charge==0 for safety)
Definition: TrackBase.h:754
unsigned short numberOfValidHits() const
number of valid hits found
Definition: TrackBase.h:798
int trackerLayersWithMeasurement() const
Definition: HitPattern.cc:534
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
double pt() const
track transverse momentum
Definition: TrackBase.h:637
double ndof() const
number of degrees of freedom of the fit
Definition: TrackBase.h:590
int numberOfValidStripLayersWithMonoAndStereo(uint16_t stripdet, uint16_t layer) const
Definition: HitPattern.cc:369
double dz() const
dz parameter (= dsz/cos(lambda)). This is the track z0 w.r.t (0,0,0) only if the refPoint is close to...
Definition: TrackBase.h:622
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
trackingRecHit_iterator recHitsEnd() const
Iterator to last hit on the track.
Definition: Track.h:91
double f[11][100]
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
unsigned short numberOfLostHits() const
number of cases where track crossed a layer without getting a hit.
Definition: TrackBase.h:801
trackingRecHit_iterator recHitsBegin() const
Iterator to first hit on the track.
Definition: Track.h:88
double eta() const
pseudorapidity of momentum vector
Definition: TrackBase.h:652
int trackerLayersWithoutMeasurement(HitCategory category) const
Definition: HitPattern.cc:553
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
const HitPattern & hitPattern() const
Access the hit pattern, indicating in which Tracker layers the track has hits.
Definition: TrackBase.h:504
double chi2() const
chi-squared of the fit
Definition: TrackBase.h:587
double normalizedChi2() const
chi-squared divided by n.d.o.f. (or chi-squared * 1e6 if n.d.o.f. is zero)
Definition: TrackBase.h:593
Structure Point Contains parameters of Gaussian fits to DMRs.
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:8
int pixelLayersWithMeasurement() const
Definition: HitPattern.cc:513
double dxy() const
dxy parameter. (This is the transverse impact parameter w.r.t. to (0,0,0) ONLY if refPoint is close t...
Definition: TrackBase.h:608