CMS 3D CMS Logo

TrackLwtnnClassifier.cc
Go to the documentation of this file.
2 
5 
8 
9 #include "getBestVertex.h"
10 
11 //from lwtnn
13 #include "lwtnn/LightweightNeuralNetwork.hh"
14 
// File-local namespace holding the lwtnn track-scoring policy and the module alias built on it.
15 namespace {
    // lwtnn: per-track scoring policy plugged into TrackMVAClassifier (see the
    // TrackLwtnnClassifier alias at the end of this namespace). For each track it fills a
    // feature map, evaluates a lwt::LightweightNeuralNetwork obtained from the EventSetup,
    // and returns {score, isReliable}.
    //
    // NOTE(review): this listing was extracted from rendered documentation and is missing
    // original source lines 21, 27, 34, 64 and 65 (see the NOTE(review) markers below).
    // Restore them from the upstream file before compiling.
16  struct lwtnn {
    // Reads the EventSetup label of the network from the "lwtnnLabel" parameter.
    // neuralNetwork_ is not set here; it is only assigned in initEvent().
17  lwtnn(const edm::ParameterSet& cfg) : lwtnnLabel_(cfg.getParameter<std::string>("lwtnnLabel")) {}
18 
    // Identifier string for this classifier policy.
19  static const char* name() { return "TrackLwtnnClassifier"; }
20 
    // NOTE(review): source line 21 (the signature of this function, which declares `desc`)
    // is missing from this listing.
    // Declares the "lwtnnLabel" parameter, defaulting to "trackSelectionLwtnn".
22  desc.add<std::string>("lwtnnLabel", "trackSelectionLwtnn");
23  }
24 
    // No per-stream setup is needed.
25  void beginStream() {}
    // Called per event: fetches the network stored under lwtnnLabel_ in
    // TrackingComponentsRecord and caches the raw pointer in neuralNetwork_.
26  void initEvent(const edm::EventSetup& es) {
    // NOTE(review): source line 27 (the declaration of lwtnnHandle) is missing from this listing.
28  es.get<TrackingComponentsRecord>().get(lwtnnLabel_, lwtnnHandle);
29  neuralNetwork_ = lwtnnHandle.product();
30  }
31 
    // Scores one track.
    //
    // Writes the input features into `inputs`, evaluates the network, and returns
    // {output, isReliable}:
    //  - output = 2 * (the single network output) - 1. This presumably maps a [0, 1]
    //    network output onto [-1, 1]; it depends on the network's output activation (confirm).
    //  - isReliable is false inside two hard-coded regions of phase space where the network
    //    is known to be unreliable (cuts below), true otherwise.
    // Throws cms::Exception("LogicError") if the network does not return exactly one output.
    // Requires initEvent() to have run first, because neuralNetwork_ is dereferenced here.
32  std::pair<float, bool> operator()(reco::Track const& trk,
33  reco::BeamSpot const& beamSpot,
    // NOTE(review): source line 34 (the `vertices` parameter used below) is missing from this listing.
35  lwt::ValueMap& inputs) const {
36  // lwt::ValueMap is typedef for std::map<std::string, double>
37  //
38  // It is cached per event to avoid constructing the map for each
39  // track while keeping the operator() interface const.
40 
    // Vertex used for the "ClosestPV" impact-parameter features below.
41  Point bestVertex = getBestVertex(trk, vertices);
42 
    // Keys are the input variable names handed to the network; presumably they must match
    // the names in the network's training configuration (confirm).
43  inputs["trk_pt"] = trk.pt();
44  inputs["trk_eta"] = trk.eta();
45  inputs["trk_lambda"] = trk.lambda();
46  inputs["trk_dxy"] = trk.dxy(beamSpot.position()); // Training done without taking absolute value
47  inputs["trk_dz"] = trk.dz(beamSpot.position()); // Training done without taking absolute value
48  inputs["trk_dxyClosestPV"] = trk.dxy(bestVertex); // Training done without taking absolute value
49  // Training done without taking absolute value; dz w.r.t. the best vertex is also clamped to [-0.2, 0.2]
50  inputs["trk_dzClosestPVNorm"] = std::max(-0.2, std::min(trk.dz(bestVertex), 0.2));
51  inputs["trk_ptErr"] = trk.ptError();
52  inputs["trk_etaErr"] = trk.etaError();
53  inputs["trk_lambdaErr"] = trk.lambdaError();
54  inputs["trk_dxyErr"] = trk.dxyError();
55  inputs["trk_dzErr"] = trk.dzError();
56  inputs["trk_nChi2"] = trk.normalizedChi2();
57  inputs["trk_ndof"] = trk.ndof();
58  inputs["trk_nInvalid"] = trk.hitPattern().numberOfLostHits(reco::HitPattern::TRACK_HITS);
59  inputs["trk_nPixel"] = trk.hitPattern().numberOfValidPixelHits();
60  inputs["trk_nStrip"] = trk.hitPattern().numberOfValidStripHits();
61  inputs["trk_nPixelLay"] = trk.hitPattern().pixelLayersWithMeasurement();
62  inputs["trk_nStripLay"] = trk.hitPattern().stripLayersWithMeasurement();
63  inputs["trk_n3DLay"] = (trk.hitPattern().numberOfValidStripLayersWithMonoAndStereo() +
    // NOTE(review): source lines 64-65 (the rest of the trk_n3DLay expression) are missing from this listing.
66  inputs["trk_algo"] = trk.algo();
67 
68  auto out = neuralNetwork_->compute(inputs);
69  // there should be only one output
70  if (out.size() != 1)
71  throw cms::Exception("LogicError") << "Expecting exactly one output from NN, got " << out.size();
72 
73  float output = 2.0 * out.begin()->second - 1.0;
74 
75  // Check if the network is known to be unreliable in that part of phase space. Hard cut values
76  // correspond to rare tracks known to be difficult for the Deep Neural Network classifier
77 
    // NOTE(review): the cuts below read features back through std::map::operator[], which
    // inserts 0.0 for an absent key, so a misspelled key here would silently change a cut
    // instead of failing. The labels T1qqqq / T5qqqqLL presumably name the simulated
    // samples in which these regions were identified (confirm).
78  bool isReliable = true;
79  //T1qqqq
80  if (std::abs(inputs["trk_dxy"]) >= 0.1 && inputs["trk_etaErr"] < 0.003 && inputs["trk_dxyErr"] < 0.03 &&
81  inputs["trk_ndof"] > 3) {
82  isReliable = false;
83  }
84  //T5qqqqLL
85  if ((inputs["trk_pt"] > 100.0) && (inputs["trk_nChi2"] < 4.0) && (inputs["trk_etaErr"] < 0.001)) {
86  isReliable = false;
87  }
88 
89  std::pair<float, bool> return_(output, isReliable);
90  return return_;
91  }
92 
    // EventSetup label of the network, taken from the "lwtnnLabel" parameter.
93  std::string lwtnnLabel_;
    // Network fetched in initEvent(). Not owned (never deleted here) and not initialized by
    // the constructor.
94  const lwt::LightweightNeuralNetwork* neuralNetwork_;
95  };
96 
    // Module type: TrackMVAClassifier driven by the lwtnn policy above. Its second template
    // argument matches the type of the `inputs` map that operator() fills.
97  using TrackLwtnnClassifier = TrackMVAClassifier<lwtnn, lwt::ValueMap>;
98 } // namespace
99 
102 
// Register the TrackLwtnnClassifier alias (TrackMVAClassifier<lwtnn, lwt::ValueMap>)
// through the DEFINE_FWK_MODULE macro from MakerMacros.h.
103 DEFINE_FWK_MODULE(TrackLwtnnClassifier);
double normalizedChi2() const
chi-squared divided by n.d.o.f. (or chi-squared * 1e6 if n.d.o.f. is zero)
Definition: TrackBase.h:572
double dxyError() const
error on dxy
Definition: TrackBase.h:716
double etaError() const
error on eta
Definition: TrackBase.h:710
int numberOfValidStripHits() const
Definition: HitPattern.h:813
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
int pixelLayersWithMeasurement() const
Definition: HitPattern.cc:492
TrackAlgorithm algo() const
Definition: TrackBase.h:526
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
double eta() const
pseudorapidity of momentum vector
Definition: TrackBase.h:617
int numberOfValidStripLayersWithMonoAndStereo(uint16_t stripdet, uint16_t layer) const
Definition: HitPattern.cc:348
double ndof() const
number of degrees of freedom of the fit
Definition: TrackBase.h:569
double pt() const
track transverse momentum
Definition: TrackBase.h:602
double ptError() const
error on Pt (set to 1000 TeV if charge==0 for safety)
Definition: TrackBase.h:696
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
double lambda() const
Lambda angle.
Definition: TrackBase.h:584
T min(T a, T b)
Definition: MathUtil.h:58
ParameterDescriptionBase * add(U const &iLabel, T const &value)
double dz() const
dz parameter (= dsz/cos(lambda)). This is the track z0 w.r.t (0,0,0) only if the refPoint is close to...
Definition: TrackBase.h:596
double dzError() const
error on dz
Definition: TrackBase.h:725
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
const HitPattern & hitPattern() const
Access the hit pattern, indicating in which Tracker layers the track has hits.
Definition: TrackBase.h:483
int stripLayersWithMeasurement() const
Definition: HitPattern.h:975
int numberOfLostHits(HitCategory category) const
Definition: HitPattern.h:861
int trackerLayersWithoutMeasurement(HitCategory category) const
Definition: HitPattern.cc:532
double lambdaError() const
error on lambda
Definition: TrackBase.h:707
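Structure Point (DMRtrends.cc): contains parameters of Gaussian fits to DMRs. Note: this cross-reference appears to be a documentation name collision. The `Point` used in this file is the position type returned by getBestVertex() and BeamSpot::position(), not this structure.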
Definition: DMRtrends.cc:57
T get() const
Definition: EventSetup.h:73
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:8
int numberOfValidPixelHits() const
Definition: HitPattern.h:801
const Point & position() const
position
Definition: BeamSpot.h:59
double dxy() const
dxy parameter. (This is the transverse impact parameter w.r.t. (0,0,0) ONLY if refPoint is close to ...
Definition: TrackBase.h:587
T const * product() const
Definition: ESHandle.h:86