CMS 3D CMS Logo

DefaultTrackMVAClassifier.cc
Go to the documentation of this file.
2 
6 
#include <limits>
#include <stdexcept>

#include "getBestVertex.h"

#include "TFile.h"
14 
15 namespace {
16 
17  template <bool PROMPT>
18  struct mva {
19  mva(const edm::ParameterSet &cfg)
20  : forestLabel_(cfg.getParameter<std::string>("GBRForestLabel")),
21  dbFileName_(cfg.getParameter<std::string>("GBRForestFileName")),
22  useForestFromDB_((!forestLabel_.empty()) & dbFileName_.empty()) {}
23 
24  void beginStream() {
25  if (!dbFileName_.empty()) {
26  TFile gbrfile(dbFileName_.c_str());
27  forestFromFile_.reset((GBRForest *)gbrfile.Get(forestLabel_.c_str()));
28  }
29  }
30 
31  void initEvent(const edm::EventSetup &es) {
32  forest_ = forestFromFile_.get();
33  if (useForestFromDB_) {
34  edm::ESHandle<GBRForest> forestHandle;
35  es.get<GBRWrapperRcd>().get(forestLabel_, forestHandle);
36  forest_ = forestHandle.product();
37  }
38  }
39 
// Computes the GBRForest response for one track.
//
// Builds a feature vector from the track's kinematics, hit pattern and fit quality
// (12 entries). When PROMPT is true, it adds four impact-parameter features w.r.t. the
// beam spot and the point returned by getBestVertex() (16 entries). It then returns
// forest_->GetClassifier() evaluated on that vector.
//
// @param trk      track to classify
// @param beamSpot beam spot used for the |dxy|/|dz| features (PROMPT only)
// @param vertices vertex collection passed to getBestVertex() (PROMPT only)
// @return the forest's classifier output for this track
//
// forest_ is dereferenced without a null check, so initEvent() must have selected a
// forest first.
40  float operator()(reco::Track const &trk,
41  reco::BeamSpot const &beamSpot,
42  reco::VertexCollection const &vertices) const {
43  auto tmva_pt_ = trk.pt();
44  auto tmva_ndof_ = trk.ndof();
45  auto tmva_nlayers_ = trk.hitPattern().trackerLayersWithMeasurement();
// NOTE(review): original lines 47-48 are missing from this copy. As a result, the
// initializer of tmva_nlayers3D_ and the declaration of tmva_nlayerslost_ (read at
// gbrVals_[8]) are not visible. The page's symbol index lists the following, which
// presumably belong there -- TODO restore from upstream:
//   - HitPattern::pixelLayersWithMeasurement()
//   - numberOfValidStripLayersWithMonoAndStereo()
//   - trackerLayersWithoutMeasurement(HitCategory) with TRACK_HITS
46  auto tmva_nlayers3D_ =
49  float chi2n = trk.normalizedChi2();
// Unmodified normalized chi2, kept as its own feature (gbrVals_[6]).
50  float chi2n_no1Dmod = chi2n;
51 
// Count the track's 1D (dimension() == 1) rec hits.
52  int count1dhits = 0;
53  for (auto ith = trk.recHitsBegin(); ith != trk.recHitsEnd(); ++ith) {
54  const auto &hit = *(*ith);
55  if (hit.dimension() == 1)
56  ++count1dhits;
57  }
58 
// With 1D hits present, chi2n is recomputed as (chi2 + n1D) / (ndof + n1D).
59  if (count1dhits > 0) {
60  float chi2 = trk.chi2();
61  float ndof = trk.ndof();
62  chi2n = (chi2 + count1dhits) / float(ndof + count1dhits);
63  }
64  auto tmva_chi2n_ = chi2n;
65  auto tmva_chi2n_no1dmod_ = chi2n_no1Dmod;
66  auto tmva_eta_ = trk.eta();
// Relative pT error; pt is clamped to 1e-6 so the division cannot be by zero.
67  auto tmva_relpterr_ = float(trk.ptError()) / std::max(float(trk.pt()), 0.000001f);
68  auto tmva_nhits_ = trk.numberOfValidHits();
// NOTE(review): original lines 69-70, which declare lostIn and lostOut, are missing
// from this copy. The symbol index lists HitPattern::numberOfLostHits(HitCategory)
// together with MISSING_INNER_HITS and MISSING_OUTER_HITS, which are presumably their
// source -- TODO restore from upstream.
71  auto tmva_minlost_ = std::min(lostIn, lostOut);
// Fraction of lost hits among valid + lost hits.
// NOTE(review): this is 0/0 (NaN) if a track had neither.
72  auto tmva_lostmidfrac_ = static_cast<float>(trk.numberOfLostHits()) /
73  static_cast<float>(trk.numberOfValidHits() + trk.numberOfLostHits());
74 
// Feature vector: 16 entries for PROMPT, 12 otherwise. The order presumably has to
// match the one the forest was trained with -- do not reorder.
75  float gbrVals_[PROMPT ? 16 : 12];
76  gbrVals_[0] = tmva_pt_;
77  gbrVals_[1] = tmva_lostmidfrac_;
78  gbrVals_[2] = tmva_minlost_;
79  gbrVals_[3] = tmva_nhits_;
80  gbrVals_[4] = tmva_relpterr_;
81  gbrVals_[5] = tmva_eta_;
82  gbrVals_[6] = tmva_chi2n_no1dmod_;
83  gbrVals_[7] = tmva_chi2n_;
84  gbrVals_[8] = tmva_nlayerslost_;
85  gbrVals_[9] = tmva_nlayers3D_;
86  gbrVals_[10] = tmva_nlayers_;
87  gbrVals_[11] = tmva_ndof_;
88 
// PROMPT is a template constant, so this branch runs only in mva<true>. In mva<false>,
// indices 12-15 would be outside the 12-entry array, but the code is never executed there.
// NOTE(review): `if constexpr` (C++17) would discard it at compile time instead.
89  if (PROMPT) {
// |dxy| and |dz| w.r.t. the beam spot and w.r.t. the point chosen by getBestVertex().
90  auto tmva_absd0_ = std::abs(trk.dxy(beamSpot.position()));
91  auto tmva_absdz_ = std::abs(trk.dz(beamSpot.position()));
92  Point bestVertex = getBestVertex(trk, vertices);
93  auto tmva_absd0PV_ = std::abs(trk.dxy(bestVertex));
94  auto tmva_absdzPV_ = std::abs(trk.dz(bestVertex));
95 
96  gbrVals_[12] = tmva_absd0PV_;
97  gbrVals_[13] = tmva_absdzPV_;
98  gbrVals_[14] = tmva_absdz_;
99  gbrVals_[15] = tmva_absd0_;
100  }
101 
102  return forest_->GetClassifier(gbrVals_);
103  }
104 
// Module type name; explicitly specialized for mva<false> and mva<true> later in this file.
105  static const char *name();
106 
// NOTE(review): the header of this function (original line 107) is missing from this copy.
// It is presumably `static void fillDescriptions(edm::ParameterSetDescription &desc) {`
// -- TODO confirm. The function declares both configuration strings with empty defaults;
// the constructor reads them via getParameter.
108  desc.add<std::string>("GBRForestLabel", std::string());
109  desc.add<std::string>("GBRForestFileName", std::string());
110  }
111 
// Forest read from the ROOT file in beginStream(); null if none was loaded.
112  std::unique_ptr<GBRForest> forestFromFile_;
// Forest evaluated by operator(). It is selected in initEvent() and points either at
// *forestFromFile_ or at the GBRWrapperRcd product.
113  const GBRForest *forest_ = nullptr; // owned by somebody else
// Value of "GBRForestLabel": the key looked up in the ROOT file and the label of the
// GBRWrapperRcd product.
114  const std::string forestLabel_;
// Value of "GBRForestFileName"; empty means no file is read.
115  const std::string dbFileName_;
// True when the forest comes from the EventSetup (label given, no file name).
// Initialized after forestLabel_ and dbFileName_, whose values it depends on.
116  const bool useForestFromDB_;
117  };
118 
119  using TrackMVAClassifierDetached = TrackMVAClassifier<mva<false>>;
120  using TrackMVAClassifierPrompt = TrackMVAClassifier<mva<true>>;
121  template <>
122  const char *mva<false>::name() {
123  return "TrackMVAClassifierDetached";
124  }
125  template <>
126  const char *mva<true>::name() {
127  return "TrackMVAClassifierPrompt";
128  }
129 
130 } // namespace
131 
134 
// NOTE(review): the includes on original lines 132-133 are not visible in this copy. The
// page's symbol index lists ModuleDef.h and MakerMacros.h, which presumably provide
// DEFINE_FWK_MODULE.
// Presumably registers the two classifier aliases defined above as framework module types.
135 DEFINE_FWK_MODULE(TrackMVAClassifierDetached);
136 DEFINE_FWK_MODULE(TrackMVAClassifierPrompt);
edm::ESHandle::product
T const * product() const
Definition: ESHandle.h:86
reco::HitPattern::MISSING_OUTER_HITS
Definition: HitPattern.h:155
pwdgSkimBPark_cfi.beamSpot
beamSpot
Definition: pwdgSkimBPark_cfi.py:5
reco::TrackBase::ptError
double ptError() const
error on Pt (set to 1000 TeV if charge==0 for safety)
Definition: TrackBase.h:754
dqmMemoryStats.float
float
Definition: dqmMemoryStats.py:127
chi2n
Definition: HIMultiTrackSelector.h:45
reco::HitPattern::trackerLayersWithoutMeasurement
int trackerLayersWithoutMeasurement(HitCategory category) const
Definition: HitPattern.cc:532
ESHandle.h
reco::Track::recHitsBegin
trackingRecHit_iterator recHitsBegin() const
Iterator to first hit on the track.
Definition: Track.h:88
f
double f[11][100]
Definition: MuScleFitUtils.cc:78
min
T min(T a, T b)
Definition: MathUtil.h:58
GBRWrapperRcd.h
GBRForest
Definition: GBRForest.h:25
reco::VertexCollection
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
edm::ParameterSetDescription
Definition: ParameterSetDescription.h:52
getBestVertex.h
reco::TrackBase::ndof
double ndof() const
number of degrees of freedom of the fit
Definition: TrackBase.h:590
hltPixelTracks_cff.chi2
chi2
Definition: hltPixelTracks_cff.py:25
beam_dqm_sourceclient-live_cfg.mva
mva
Definition: beam_dqm_sourceclient-live_cfg.py:122
reco::HitPattern::numberOfLostHits
int numberOfLostHits(HitCategory category) const
Definition: HitPattern.h:860
reco::TrackBase::numberOfValidHits
unsigned short numberOfValidHits() const
number of valid hits found
Definition: TrackBase.h:798
reco::HitPattern::pixelLayersWithMeasurement
int pixelLayersWithMeasurement() const
Definition: HitPattern.cc:492
ndof
Definition: HIMultiTrackSelector.h:49
reco::TrackBase::pt
double pt() const
track transverse momentum
Definition: TrackBase.h:637
TrackMVAClassifier.h
MakerMacros.h
TrackMVAClassifier
Definition: TrackMVAClassifier.h:95
Track.h
edm::EventSetup::get
T get() const
Definition: EventSetup.h:80
DEFINE_FWK_MODULE
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
reco::HitPattern::trackerLayersWithMeasurement
int trackerLayersWithMeasurement() const
Definition: HitPattern.cc:513
fillDescriptions
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
reco::Track::recHitsEnd
trackingRecHit_iterator recHitsEnd() const
Iterator one past the last hit on the track (end sentinel).
Definition: Track.h:91
reco::BeamSpot
Definition: BeamSpot.h:21
reco::TrackBase::numberOfLostHits
unsigned short numberOfLostHits() const
number of cases where track crossed a layer without getting a hit.
Definition: TrackBase.h:801
reco::Track
Definition: Track.h:27
edm::ESHandle< GBRForest >
reco::TrackBase::dz
double dz() const
dz parameter (= dsz/cos(lambda)). This is the track z0 w.r.t. (0,0,0) only if the refPoint is close to...
Definition: TrackBase.h:622
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
Vertex.h
Point
Structure Point Contains parameters of Gaussian fits to DMRs.
Definition: DMRtrends.cc:57
edm::ParameterSet
Definition: ParameterSet.h:47
SiStripPI::max
Definition: SiStripPayloadInspectorHelper.h:169
reco::TrackBase::eta
double eta() const
pseudorapidity of momentum vector
Definition: TrackBase.h:652
ModuleDef.h
edm::EventSetup
Definition: EventSetup.h:57
reco::TrackBase::chi2
double chi2() const
chi-squared of the fit
Definition: TrackBase.h:587
reco::TrackBase::normalizedChi2
double normalizedChi2() const
chi-squared divided by n.d.o.f. (or chi-squared * 1e6 if n.d.o.f. is zero)
Definition: TrackBase.h:593
get
#define get
reco::HitPattern::TRACK_HITS
Definition: HitPattern.h:155
looper.cfg
cfg
Definition: looper.py:297
reco::TrackBase::hitPattern
const HitPattern & hitPattern() const
Access the hit pattern, indicating in which Tracker layers the track has hits.
Definition: TrackBase.h:504
submitPVResolutionJobs.desc
string desc
Definition: submitPVResolutionJobs.py:251
std
Definition: JetResolutionObject.h:76
relativeConstraints.empty
bool empty
Definition: relativeConstraints.py:46
reco::HitPattern::MISSING_INNER_HITS
Definition: HitPattern.h:155
GBRWrapperRcd
Definition: GBRWrapperRcd.h:24
Skims_PA_cff.name
name
Definition: Skims_PA_cff.py:17
EventSetup.h
funct::abs
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
reco::TrackBase::dxy
double dxy() const
dxy parameter. (This is the transverse impact parameter w.r.t. (0,0,0) ONLY if refPoint is close t...
Definition: TrackBase.h:608
hit
Definition: SiStripHitEffFromCalibTree.cc:88
reco::HitPattern::numberOfValidStripLayersWithMonoAndStereo
int numberOfValidStripLayersWithMonoAndStereo(uint16_t stripdet, uint16_t layer) const
Definition: HitPattern.cc:348
getBestVertex
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:8
pwdgSkimBPark_cfi.vertices
vertices
Definition: pwdgSkimBPark_cfi.py:7