CMS 3D CMS Logo

TrackLwtnnClassifier.cc
Go to the documentation of this file.
2 
5 
8 
9 #include "getBestVertex.h"
10 
11 //from lwtnn
13 #include "lwtnn/LightweightNeuralNetwork.hh"
14 
// NOTE(review): this listing appears to be text scraped from a rendered Doxygen
// source page: every code line is prefixed with its original line number, and
// original lines 21, 27, 34, 58 and 63-65 are absent (the numbering skips them).
// As shown it will not compile; restore those lines from the real
// TrackLwtnnClassifier.cc rather than reconstructing them by hand.
15 namespace {
// lwtnn: classification policy plugged into TrackMVAClassifier (see the
// TrackLwtnnClassifier alias at the end of this namespace). It reads the
// "lwtnnLabel" parameter, fetches a lwt::LightweightNeuralNetwork from
// TrackingComponentsRecord in initEvent(), and scores one track per call of
// operator().
16  struct lwtnn {
// Stores the EventSetup label of the network, taken from the "lwtnnLabel"
// configuration parameter.
17  lwtnn(const edm::ParameterSet& cfg) : lwtnnLabel_(cfg.getParameter<std::string>("lwtnnLabel")) {}
18 
// Returns the fixed string "TrackLwtnnClassifier" (presumably used by
// TrackMVAClassifier when naming/describing the module — TODO confirm).
19  static const char* name() { return "TrackLwtnnClassifier"; }
20 
// NOTE(review): original line 21 (the opening of this function) is missing.
// Judging from `desc.add` below and the closing brace on line 23, it is
// presumably `static void fillDescriptions(edm::ParameterSetDescription& desc) {`
// — confirm against the real source.
// Declares the "lwtnnLabel" parameter with default "trackSelectionLwtnn".
22  desc.add<std::string>("lwtnnLabel", "trackSelectionLwtnn");
23  }
24 
// No per-stream initialisation is needed.
25  void beginStream() {}
// Per-event hook: looks up the network labelled lwtnnLabel_ in
// TrackingComponentsRecord and caches a raw const pointer to it in
// neuralNetwork_ (this struct never frees it).
// NOTE(review): original line 27 is missing; it presumably declares
// `lwtnnHandle` (an edm::ESHandle per the cross-reference list) — confirm.
26  void initEvent(const edm::EventSetup& es) {
28  es.get<TrackingComponentsRecord>().get(lwtnnLabel_, lwtnnHandle);
29  neuralNetwork_ = lwtnnHandle.product();
30  }
31 
// Scores one track with the neural network.
//   trk      - track to classify
//   beamSpot - dxy/dz inputs are computed w.r.t. its position
//   vertices - used by getBestVertex() to choose the reference vertex
//              (declared on the missing original line 34)
//   inputs   - caller-provided map, filled here with named input features
//              and handed to the network
// Returns {score, isReliable}: score is 2*x - 1, where x is the single network
// output (this maps [0,1] onto [-1,1] if x is a probability — TODO confirm);
// isReliable is false inside the hard-coded phase-space regions below.
// Throws cms::Exception("LogicError") if the network does not return exactly
// one output.
32  std::pair<float, bool> operator()(reco::Track const& trk,
33  reco::BeamSpot const& beamSpot,
// NOTE(review): original line 34 is missing here — presumably
// `reco::VertexCollection const& vertices,` since `vertices` is used below.
35  lwt::ValueMap& inputs) const {
36  // lwt::ValueMap is typedef for std::map<std::string, double>
37  //
38  // It is cached per event to avoid constructing the map for each
39  // track while keeping the operator() interface const.
40 
// Reference vertex for the *ClosestPV inputs below.
41  Point bestVertex = getBestVertex(trk, vertices);
42 
43  inputs["trk_pt"] = trk.pt();
44  inputs["trk_eta"] = trk.eta();
45  inputs["trk_lambda"] = trk.lambda();
46  inputs["trk_dxy"] = trk.dxy(beamSpot.position()); // Training done without taking absolute value
47  inputs["trk_dz"] = trk.dz(beamSpot.position()); // Training done without taking absolute value
48  inputs["trk_dxyClosestPV"] = trk.dxy(bestVertex); // Training done without taking absolute value
49  // Training done without taking absolute value
// dz w.r.t. the chosen vertex, clamped to [-0.2, 0.2] (a clamp, not a
// normalisation, despite the "Norm" in the key name).
50  inputs["trk_dzClosestPVNorm"] = std::max(-0.2, std::min(trk.dz(bestVertex), 0.2));
51  inputs["trk_ptErr"] = trk.ptError();
52  inputs["trk_etaErr"] = trk.etaError();
53  inputs["trk_lambdaErr"] = trk.lambdaError();
54  inputs["trk_dxyErr"] = trk.dxyError();
55  inputs["trk_dzErr"] = trk.dzError();
56  inputs["trk_nChi2"] = trk.normalizedChi2();
57  inputs["trk_ndof"] = trk.ndof();
// NOTE(review): original line 58 is missing here — presumably another
// hit-pattern input; restore it from the real source.
59  inputs["trk_nPixel"] = trk.hitPattern().numberOfValidPixelHits();
60  inputs["trk_nStrip"] = trk.hitPattern().numberOfValidStripHits();
61  inputs["trk_nPixelLay"] = trk.hitPattern().pixelLayersWithMeasurement();
62  inputs["trk_nStripLay"] = trk.hitPattern().stripLayersWithMeasurement();
// NOTE(review): original lines 63-65 are missing here. The page's
// cross-reference list mentions numberOfLostHits, trackerLayersWithoutMeasurement
// and numberOfValidStripLayersWithMonoAndStereo, which presumably fed further
// inputs. Presumably every input the trained network expects must be present
// before compute() is called — restore these lines from the real source.
66  inputs["trk_algo"] = trk.algo();
67 
// Evaluate the network; `out` is map-like (begin()->second is read below).
68  auto out = neuralNetwork_->compute(inputs);
69  // there should be exactly one output
70  if (out.size() != 1)
71  throw cms::Exception("LogicError") << "Expecting exactly one output from NN, got " << out.size();
72 
// Rescale the single output x to 2*x - 1.
73  float output = 2.0 * out.begin()->second - 1.0;
74 
75  //Check if the network is known to be unreliable in that part of phase space. Hard cut values
76  //correspond to rare tracks known to be difficult for the Deep Neural Network classifier
77 
// NOTE(review): the cuts below re-read values just stored in `inputs` via
// operator[] (one string-keyed lookup each); reading the track accessors
// directly would give the same values without the lookups.
78  bool isReliable = true;
79  //T1qqqq
80  if (std::abs(inputs["trk_dxy"]) >= 0.1 && inputs["trk_etaErr"] < 0.003 && inputs["trk_dxyErr"] < 0.03 &&
81  inputs["trk_ndof"] > 3) {
82  isReliable = false;
83  }
84  //T5qqqqLL
85  if ((inputs["trk_pt"] > 100.0) && (inputs["trk_nChi2"] < 4.0) && (inputs["trk_etaErr"] < 0.001)) {
86  isReliable = false;
87  }
88 
89  std::pair<float, bool> return_(output, isReliable);
90  return return_;
91  }
92 
// EventSetup label of the network (from the "lwtnnLabel" parameter).
93  std::string lwtnnLabel_;
// Set in initEvent(); it has no initializer, so it is indeterminate before
// the first initEvent() call (NOTE(review): consider `= nullptr`).
94  const lwt::LightweightNeuralNetwork* neuralNetwork_;
95  };
96 
// The actual module type: TrackMVAClassifier instantiated with the lwtnn
// policy; lwt::ValueMap is presumably the type of the `inputs` cache handed to
// operator() — TODO confirm against TrackMVAClassifier.h.
97  using TrackLwtnnClassifier = TrackMVAClassifier<lwtnn, lwt::ValueMap>;
98 } // namespace
99 
102 
// Presumably registers TrackLwtnnClassifier with the framework's plugin
// factory (DEFINE_FWK_MODULE comes from MakerMacros.h) so that it can be
// instantiated from configuration.
// NOTE(review): original lines 100-101 are missing from this listing; the
// page's cross-reference list mentions ModuleDef.h and MakerMacros.h, which are
// presumably included there — restore them from the real source.
103 DEFINE_FWK_MODULE(TrackLwtnnClassifier);
edm::ESHandle::product
T const * product() const
Definition: ESHandle.h:86
edm::ParameterSetDescription::add
ParameterDescriptionBase * add(U const &iLabel, T const &value)
Definition: ParameterSetDescription.h:95
pwdgSkimBPark_cfi.beamSpot
beamSpot
Definition: pwdgSkimBPark_cfi.py:5
reco::TrackBase::ptError
double ptError() const
error on Pt (set to 1000 TeV if charge==0 for safety)
Definition: TrackBase.h:702
reco::TrackBase::lambdaError
double lambdaError() const
error on lambda
Definition: TrackBase.h:713
reco::HitPattern::trackerLayersWithoutMeasurement
int trackerLayersWithoutMeasurement(HitCategory category) const
Definition: HitPattern.cc:532
ESHandle.h
reco::TrackBase::etaError
double etaError() const
error on eta
Definition: TrackBase.h:716
convertSQLitetoXML_cfg.output
output
Definition: convertSQLitetoXML_cfg.py:32
min
T min(T a, T b)
Definition: MathUtil.h:58
reco::VertexCollection
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
edm::ParameterSetDescription
Definition: ParameterSetDescription.h:52
getBestVertex.h
reco::TrackBase::ndof
double ndof() const
number of degrees of freedom of the fit
Definition: TrackBase.h:575
reco::HitPattern::numberOfLostHits
int numberOfLostHits(HitCategory category) const
Definition: HitPattern.h:860
reco::TrackBase::dxyError
double dxyError() const
error on dxy
Definition: TrackBase.h:722
reco::HitPattern::pixelLayersWithMeasurement
int pixelLayersWithMeasurement() const
Definition: HitPattern.cc:492
reco::TrackBase::pt
double pt() const
track transverse momentum
Definition: TrackBase.h:608
TrackMVAClassifier.h
MakerMacros.h
TrackMVAClassifier
Definition: TrackMVAClassifier.h:95
reco::HitPattern::stripLayersWithMeasurement
int stripLayersWithMeasurement() const
Definition: HitPattern.h:974
Track.h
edm::EventSetup::get
T get() const
Definition: EventSetup.h:73
DEFINE_FWK_MODULE
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
fillDescriptions
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
reco::BeamSpot
Definition: BeamSpot.h:21
reco::Track
Definition: Track.h:27
edm::ESHandle
Definition: DTSurvey.h:22
reco::TrackBase::dz
double dz() const
dz parameter (= dsz/cos(lambda)). This is the track z0 w.r.t. (0,0,0) only if the refPoint is close to...
Definition: TrackBase.h:602
AlCaHLTBitMon_QueryRunRegistry.string
string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
Vertex.h
Point
Structure Point Contains parameters of Gaussian fits to DMRs.
Definition: DMRtrends.cc:57
edm::ParameterSet
Definition: ParameterSet.h:36
SiStripPI::max
Definition: SiStripPayloadInspectorHelper.h:169
reco::TrackBase::eta
double eta() const
pseudorapidity of momentum vector
Definition: TrackBase.h:623
reco::TrackBase::dzError
double dzError() const
error on dz
Definition: TrackBase.h:731
ModuleDef.h
PixelMapPlotter.inputs
inputs
Definition: PixelMapPlotter.py:490
edm::EventSetup
Definition: EventSetup.h:57
reco::TrackBase::normalizedChi2
double normalizedChi2() const
chi-squared divided by n.d.o.f. (or chi-squared * 1e6 if n.d.o.f. is zero)
Definition: TrackBase.h:578
get
#define get
reco::HitPattern::TRACK_HITS
Definition: HitPattern.h:155
looper.cfg
cfg
Definition: looper.py:297
reco::TrackBase::algo
TrackAlgorithm algo() const
Definition: TrackBase.h:532
reco::TrackBase::hitPattern
const HitPattern & hitPattern() const
Access the hit pattern, indicating in which Tracker layers the track has hits.
Definition: TrackBase.h:489
std
Definition: JetResolutionObject.h:76
reco::HitPattern::numberOfValidStripHits
int numberOfValidStripHits() const
Definition: HitPattern.h:812
TrackingComponentsRecord.h
Skims_PA_cff.name
name
Definition: Skims_PA_cff.py:17
EventSetup.h
reco::HitPattern::numberOfValidPixelHits
int numberOfValidPixelHits() const
Definition: HitPattern.h:800
MillePedeFileConverter_cfg.out
out
Definition: MillePedeFileConverter_cfg.py:31
cms::Exception
Definition: Exception.h:70
funct::abs
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
reco::TrackBase::lambda
double lambda() const
Lambda angle.
Definition: TrackBase.h:590
reco::TrackBase::dxy
double dxy() const
dxy parameter. (This is the transverse impact parameter w.r.t. (0,0,0) ONLY if refPoint is close t...
Definition: TrackBase.h:593
reco::HitPattern::numberOfValidStripLayersWithMonoAndStereo
int numberOfValidStripLayersWithMonoAndStereo(uint16_t stripdet, uint16_t layer) const
Definition: HitPattern.cc:348
getBestVertex
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:8
pwdgSkimBPark_cfi.vertices
vertices
Definition: pwdgSkimBPark_cfi.py:7
TrackingComponentsRecord
Definition: TrackingComponentsRecord.h:12