CMS 3D CMS Logo

DefaultTrackMVAClassifier.cc
Go to the documentation of this file.
2 
6 
9 #include <limits>
10 
11 #include "getBestVertex.h"
12 
13 #include "TFile.h"
14 
15 namespace {
16 
17  template <bool PROMPT>
18  struct mva {
19  mva(const edm::ParameterSet &cfg)
20  : forestLabel_(cfg.getParameter<std::string>("GBRForestLabel")),
21  dbFileName_(cfg.getParameter<std::string>("GBRForestFileName")),
22  useForestFromDB_((!forestLabel_.empty()) & dbFileName_.empty()) {}
23 
24  void beginStream() {
25  if (!dbFileName_.empty()) {
26  TFile gbrfile(dbFileName_.c_str());
27  forestFromFile_.reset((GBRForest *)gbrfile.Get(forestLabel_.c_str()));
28  }
29  }
30 
31  void initEvent(const edm::EventSetup &es) {
32  forest_ = forestFromFile_.get();
33  if (useForestFromDB_) {
34  edm::ESHandle<GBRForest> forestHandle;
35  es.get<GBRWrapperRcd>().get(forestLabel_, forestHandle);
36  forest_ = forestHandle.product();
37  }
38  }
39 
// Evaluate the GBRForest classifier on one track and return its output.
//
// Builds the input array gbrVals_ (12 entries, or 16 when PROMPT is true) from the
// track's kinematics, fit quality and hit pattern, then calls forest_->GetClassifier.
//
// @param trk       track to classify
// @param beamSpot  used only when PROMPT: |dxy| and |dz| w.r.t. the beam-spot position
// @param vertices  used only when PROMPT: |dxy| and |dz| w.r.t. getBestVertex(trk, vertices)
// @return          the classifier output of forest_
//
// NOTE(review): forest_ is dereferenced without a null check. It is set in initEvent()
// and would be null if neither a file-loaded nor an EventSetup forest is available —
// confirm the configuration always provides one.
// NOTE(review): original lines 47-48 and 69-70 are not shown in this listing. They
// presumably complete tmva_nlayers3D_ and define tmva_nlayerslost_, lostIn and lostOut
// (all used below), likely from trk.hitPattern() — confirm against the full source.
40  float operator()(reco::Track const &trk,
41  reco::BeamSpot const &beamSpot,
42  reco::VertexCollection const &vertices) const {
43  auto tmva_pt_ = trk.pt();
44  auto tmva_ndof_ = trk.ndof();
45  auto tmva_nlayers_ = trk.hitPattern().trackerLayersWithMeasurement();
46  auto tmva_nlayers3D_ =
49  float chi2n = trk.normalizedChi2();
50  float chi2n_no1Dmod = chi2n;
51 
// Count the track's 1D hits; they are used below to modify the normalized chi2.
52  int count1dhits = 0;
53  for (auto ith = trk.recHitsBegin(); ith != trk.recHitsEnd(); ++ith) {
54  const auto &hit = *(*ith);
55  if (hit.dimension() == 1)
56  ++count1dhits;
57  }
58 
// Add one unit to both chi2 and ndof per 1D hit.
// chi2n_no1Dmod keeps the unmodified normalized chi2.
59  if (count1dhits > 0) {
60  float chi2 = trk.chi2();
61  float ndof = trk.ndof();
62  chi2n = (chi2 + count1dhits) / float(ndof + count1dhits);
63  }
64  auto tmva_chi2n_ = chi2n;
65  auto tmva_chi2n_no1dmod_ = chi2n_no1Dmod;
66  auto tmva_eta_ = trk.eta();
// Relative pT error; pT is floored at 1e-6 so the division never divides by zero.
67  auto tmva_relpterr_ = float(trk.ptError()) / std::max(float(trk.pt()), 0.000001f);
68  auto tmva_nhits_ = trk.numberOfValidHits();
71  auto tmva_minlost_ = std::min(lostIn, lostOut);
// Fraction of lost hits among (valid + lost) hits.
// NOTE(review): the denominator is 0 for a track with no valid and no lost hits,
// which gives NaN. This is presumably impossible for reconstructed tracks; confirm.
72  auto tmva_lostmidfrac_ = static_cast<float>(trk.numberOfLostHits()) /
73  static_cast<float>(trk.numberOfValidHits() + trk.numberOfLostHits());
74 
// Feature order presumably must match the variable order the forest was trained
// with — do not reorder these assignments.
75  float gbrVals_[PROMPT ? 16 : 12];
76  gbrVals_[0] = tmva_pt_;
77  gbrVals_[1] = tmva_lostmidfrac_;
78  gbrVals_[2] = tmva_minlost_;
79  gbrVals_[3] = tmva_nhits_;
80  gbrVals_[4] = tmva_relpterr_;
81  gbrVals_[5] = tmva_eta_;
82  gbrVals_[6] = tmva_chi2n_no1dmod_;
83  gbrVals_[7] = tmva_chi2n_;
84  gbrVals_[8] = tmva_nlayerslost_;
85  gbrVals_[9] = tmva_nlayers3D_;
86  gbrVals_[10] = tmva_nlayers_;
87  gbrVals_[11] = tmva_ndof_;
88 
// Prompt-only features: absolute transverse/longitudinal impact parameters
// w.r.t. the best vertex and w.r.t. the beam spot.
// NOTE(review): PROMPT is a template parameter. With C++17, `if constexpr` would remove
// this branch, and its indices 12-15 into the 12-element array, from the non-prompt
// instantiation instead of relying on it never executing.
89  if (PROMPT) {
90  auto tmva_absd0_ = std::abs(trk.dxy(beamSpot.position()));
91  auto tmva_absdz_ = std::abs(trk.dz(beamSpot.position()));
92  Point bestVertex = getBestVertex(trk, vertices);
93  auto tmva_absd0PV_ = std::abs(trk.dxy(bestVertex));
94  auto tmva_absdzPV_ = std::abs(trk.dz(bestVertex));
95 
96  gbrVals_[12] = tmva_absd0PV_;
97  gbrVals_[13] = tmva_absdzPV_;
98  gbrVals_[14] = tmva_absdz_;
99  gbrVals_[15] = tmva_absd0_;
100  }
101 
102  return forest_->GetClassifier(gbrVals_);
103  }
104 
// Framework module name for this classifier flavour. Defined by the explicit
// specializations mva<false>::name() / mva<true>::name() after the struct.
105  static const char *name();
106 
// NOTE(review): the opening line of this function (original line 107) is not shown in
// this listing. The body registers both string parameters with empty-string defaults.
// With both empty, no forest is configured at all.
108  desc.add<std::string>("GBRForestLabel", std::string());
109  desc.add<std::string>("GBRForestFileName", std::string());
110  }
111 
// Forest read from dbFileName_ in beginStream(); owned here. Null if no file is configured.
112  std::unique_ptr<GBRForest> forestFromFile_;
// Forest actually evaluated by operator(), selected in initEvent(). It points either into
// forestFromFile_ or at the EventSetup product, so it is never owned here.
113  const GBRForest *forest_ = nullptr; // owned by somebody else
// "GBRForestLabel": key of the object in the ROOT file, or the EventSetup label.
// NOTE: declaration order matters. forestLabel_ and dbFileName_ must stay declared before
// useForestFromDB_, because the constructor reads them to initialize it.
114  const std::string forestLabel_;
// "GBRForestFileName": ROOT file to load the forest from; empty means "not from a file".
115  const std::string dbFileName_;
// True iff a label is set and no file name is, i.e. the forest comes from the EventSetup.
116  const bool useForestFromDB_;
117  };
118 
119  using TrackMVAClassifierDetached = TrackMVAClassifier<mva<false>>;
120  using TrackMVAClassifierPrompt = TrackMVAClassifier<mva<true>>;
121  template <>
122  const char *mva<false>::name() {
123  return "TrackMVAClassifierDetached";
124  }
125  template <>
126  const char *mva<true>::name() {
127  return "TrackMVAClassifierPrompt";
128  }
129 
130 } // namespace
131 
134 
// Register both classifier flavours with the framework's plugin system
// (DEFINE_FWK_MODULE, from MakerMacros.h).
135 DEFINE_FWK_MODULE(TrackMVAClassifierDetached);
136 DEFINE_FWK_MODULE(TrackMVAClassifierPrompt);
double normalizedChi2() const
chi-squared divided by n.d.o.f. (or chi-squared * 1e6 if n.d.o.f. is zero)
Definition: TrackBase.h:572
unsigned short numberOfLostHits() const
number of cases where track crossed a layer without getting a hit.
Definition: TrackBase.h:743
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
int pixelLayersWithMeasurement() const
Definition: HitPattern.cc:492
int trackerLayersWithMeasurement() const
Definition: HitPattern.cc:513
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
double eta() const
pseudorapidity of momentum vector
Definition: TrackBase.h:617
int numberOfValidStripLayersWithMonoAndStereo(uint16_t stripdet, uint16_t layer) const
Definition: HitPattern.cc:348
double chi2() const
chi-squared of the fit
Definition: TrackBase.h:566
double ndof() const
number of degrees of freedom of the fit
Definition: TrackBase.h:569
double pt() const
track transverse momentum
Definition: TrackBase.h:602
double ptError() const
error on Pt (set to 1000 TeV if charge==0 for safety)
Definition: TrackBase.h:696
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
double f[11][100]
unsigned short numberOfValidHits() const
number of valid hits found
Definition: TrackBase.h:740
T min(T a, T b)
Definition: MathUtil.h:58
trackingRecHit_iterator recHitsBegin() const
Iterator to first hit on the track.
Definition: Track.h:88
ParameterDescriptionBase * add(U const &iLabel, T const &value)
double dz() const
dz parameter (= dsz/cos(lambda)). This is the track z0 w.r.t (0,0,0) only if the refPoint is close to...
Definition: TrackBase.h:596
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
const HitPattern & hitPattern() const
Access the hit pattern, indicating in which Tracker layers the track has hits.
Definition: TrackBase.h:483
int numberOfLostHits(HitCategory category) const
Definition: HitPattern.h:861
int trackerLayersWithoutMeasurement(HitCategory category) const
Definition: HitPattern.cc:532
Structure Point Contains parameters of Gaussian fits to DMRs.
Definition: DMRtrends.cc:57
T get() const
Definition: EventSetup.h:73
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:8
const Point & position() const
position
Definition: BeamSpot.h:59
double dxy() const
dxy parameter. (This is the transverse impact parameter w.r.t. (0,0,0) ONLY if refPoint is close t...
Definition: TrackBase.h:587
T const * product() const
Definition: ESHandle.h:86
trackingRecHit_iterator recHitsEnd() const
Iterator to last hit on the track.
Definition: Track.h:91