CMS 3D CMS Logo

TrackLwtnnClassifier.cc
Go to the documentation of this file.
2 
5 
8 
9 #include "getBestVertex.h"
10 
11 //from lwtnn
13 #include "lwtnn/LightweightNeuralNetwork.hh"
14 
// Anonymous namespace: everything below is local to this translation unit.
15 namespace {
    // `lwtnn` is the algorithm policy handed to TrackMVAClassifier (see the
    // `TrackLwtnnClassifier` alias at the end of this namespace). It scores each
    // track with a lwt::LightweightNeuralNetwork obtained from the EventSetup.
    //
    // NOTE(review): this is a Doxygen source listing. Listing lines 23, 29, 36,
    // 64 and 65 are missing from it (the numbering jumps over them). Restore
    // them from the upstream source rather than reconstructing them by hand.
    // The comments below say what each gap presumably held.
16  struct lwtnn {
    // Reads the EventSetup label of the network from the "lwtnnLabel" parameter.
17  lwtnn(const edm::ParameterSet& cfg):
18  lwtnnLabel_(cfg.getParameter<std::string>("lwtnnLabel"))
19  {}
20 
    // Name reported for this classifier.
21  static const char *name() { return "TrackLwtnnClassifier"; }
22 
    // Declares the "lwtnnLabel" parameter with default "trackSelectionLwtnn".
    // NOTE(review): the function header (listing line 23) is missing here.
    // The body calls desc.add<std::string>(label, default), so the header
    // presumably reads
    //   static void fillDescriptions(edm::ParameterSetDescription& desc) {
    // Confirm against upstream.
24  desc.add<std::string>("lwtnnLabel", "trackSelectionLwtnn");
25  }
26 
    // Intentionally empty: no per-stream setup is done.
27  void beginStream() {}
    // Looks up the network in TrackingComponentsRecord under lwtnnLabel_ and
    // stores a non-owning pointer to it in neuralNetwork_.
    // NOTE(review): listing line 29 (the declaration of lwtnnHandle) is missing.
    // Given the product() call below, it is presumably
    //   edm::ESHandle<lwt::LightweightNeuralNetwork> lwtnnHandle;
    // Confirm against upstream.
28  void initEvent(const edm::EventSetup& es) {
30  es.get<TrackingComponentsRecord>().get(lwtnnLabel_, lwtnnHandle);
31  neuralNetwork_ = lwtnnHandle.product();
32  }
33 
    // Scores one track.
    //
    // Fills `inputs` with the track features and evaluates the network.
    // Returns {score, isReliable}, where:
    //   - score = 2 * (network output) - 1;
    //   - isReliable is false inside the hard-coded phase-space regions
    //     listed further down.
    // Throws cms::Exception with category "LogicError" unless the network
    // returns exactly one output.
    //
    // NOTE(review): listing line 36 (one parameter) is missing. `vertices` is
    // passed to getBestVertex, whose signature takes
    // `reco::VertexCollection const&`, so the line is presumably
    //   reco::VertexCollection const & vertices,
    // Confirm against upstream.
34  std::pair<float,bool> operator()(reco::Track const & trk,
35  reco::BeamSpot const & beamSpot,
37  lwt::ValueMap & inputs) const {
38  // lwt::ValueMap is typedef for std::map<std::string, double>
39  //
40  // It is cached per event to avoid constructing the map for each
41  // track while keeping the operator() interface const.
42 
    // Reference point for the "ClosestPV" features below.
43  Point bestVertex = getBestVertex(trk,vertices);
44 
45  inputs["trk_pt"] = trk.pt();
46  inputs["trk_eta"] = trk.eta();
47  inputs["trk_lambda"] = trk.lambda();
48  inputs["trk_dxy"] = trk.dxy(beamSpot.position()); // Training done without taking absolute value
49  inputs["trk_dz"] = trk.dz(beamSpot.position()); // Training done without taking absolute value
50  inputs["trk_dxyClosestPV"] = trk.dxy(bestVertex); // Training done without taking absolute value
    // The value is clamped to [-0.2, 0.2].
    // NOTE(review): no normalisation is applied here despite the "Norm" suffix
    // in the key. Confirm this matches the training inputs.
51  inputs["trk_dzClosestPVNorm"] = std::max(-0.2, std::min(trk.dz(bestVertex), 0.2)); // Training done without taking absolute value
52  inputs["trk_ptErr"] = trk.ptError();
53  inputs["trk_etaErr"] = trk.etaError();
54  inputs["trk_lambdaErr"] = trk.lambdaError();
55  inputs["trk_dxyErr"] = trk.dxyError();
56  inputs["trk_dzErr"] = trk.dzError();
57  inputs["trk_nChi2"] = trk.normalizedChi2();
58  inputs["trk_ndof"] = trk.ndof();
59  inputs["trk_nInvalid"] = trk.hitPattern().numberOfLostHits(reco::HitPattern::TRACK_HITS);
60  inputs["trk_nPixel"] = trk.hitPattern().numberOfValidPixelHits();
61  inputs["trk_nStrip"] = trk.hitPattern().numberOfValidStripHits();
62  inputs["trk_nPixelLay"] = trk.hitPattern().pixelLayersWithMeasurement();
63  inputs["trk_nStripLay"] = trk.hitPattern().stripLayersWithMeasurement();
    // NOTE(review): listing lines 64-65 are missing here.
    // This page's cross-reference entries mention two HitPattern methods that
    // appear in no visible line:
    //   - trackerLayersWithoutMeasurement(HitCategory)
    //   - numberOfValidStripLayersWithMonoAndStereo(...)
    // So these two lines presumably add two more network inputs built from
    // them. Their map keys cannot be recovered from this listing; restore
    // them from upstream, since the network presumably expects them.
66  inputs["trk_algo"] = trk.algo();
67 
68  auto out = neuralNetwork_->compute(inputs);
69  // there should be exactly one output
70  if(out.size() != 1) throw cms::Exception("LogicError") << "Expecting exactly one output from NN, got " << out.size();
71 
    // Maps the network output x to 2x - 1. If x lies in [0, 1] (e.g. a sigmoid
    // output; assumption, confirm), the score lies in [-1, 1].
72  float output = 2.0*out.begin()->second-1.0;
73 
74 
75  //Check if the network is known to be unreliable in that part of phase space. Hard cut values
76  //correspond to rare tracks known to be difficult for the Deep Neural Network classifier
    // NOTE(review): the cuts below re-read features through
    // std::map::operator[], which does a string-keyed lookup on every access
    // (and would insert 0.0 for a missing key). Every value read here was
    // assigned from the track just above, so local copies would avoid the
    // repeated lookups.
77 
78  bool isReliable = true;
79  //T1qqqq
80  if(std::abs(inputs["trk_dxy"])>=0.1 && inputs["trk_etaErr"]<0.003 && inputs["trk_dxyErr"]<0.03 && inputs["trk_ndof"]>3){
81  isReliable = false;
82  }
83  //T5qqqqLL
84  if((inputs["trk_pt"]>100.0)&&(inputs["trk_nChi2"]<4.0)&&(inputs["trk_etaErr"]<0.001)){
85  isReliable = false;
86  }
87 
88  std::pair<float, bool> return_ (output, isReliable);
89  return return_;
90  }
91 
92 
    // EventSetup label of the network (from the "lwtnnLabel" parameter).
93  std::string lwtnnLabel_;
    // Non-owning pointer; assigned only in initEvent().
    // NOTE(review): it has no initializer and the constructor does not set it,
    // so it is indeterminate until initEvent() runs. Consider `= nullptr`.
94  const lwt::LightweightNeuralNetwork *neuralNetwork_;
95  };
96 
    // The module type: TrackMVAClassifier driven by the `lwtnn` policy above.
    // lwt::ValueMap is presumably the per-event cache that reaches operator()
    // as `inputs`.
97  using TrackLwtnnClassifier = TrackMVAClassifier<lwtnn, lwt::ValueMap>;
98 }
99 
102 
// Presumably registers TrackLwtnnClassifier as a framework module plugin via
// the DEFINE_FWK_MODULE macro (defined in MakerMacros.h).
// NOTE(review): listing lines 100-101, just above, are missing from this
// Doxygen listing. They presumably held the #include lines that make
// DEFINE_FWK_MODULE available; restore them from upstream.
103 DEFINE_FWK_MODULE(TrackLwtnnClassifier);
104 
double normalizedChi2() const
chi-squared divided by n.d.o.f. (or chi-squared * 1e6 if n.d.o.f. is zero)
Definition: TrackBase.h:600
double dxyError() const
error on dxy
Definition: TrackBase.h:847
double etaError() const
error on eta
Definition: TrackBase.h:835
int numberOfValidStripHits() const
Definition: HitPattern.h:931
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
int pixelLayersWithMeasurement() const
Definition: HitPattern.cc:538
TrackAlgorithm algo() const
Definition: TrackBase.h:536
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
double eta() const
pseudorapidity of momentum vector
Definition: TrackBase.h:690
int numberOfValidStripLayersWithMonoAndStereo(uint16_t stripdet, uint16_t layer) const
Definition: HitPattern.cc:379
double ndof() const
number of degrees of freedom of the fit
Definition: TrackBase.h:594
double pt() const
track transverse momentum
Definition: TrackBase.h:660
double ptError() const
error on Pt (set to 1000 TeV if charge==0 for safety)
Definition: TrackBase.h:814
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
double lambda() const
Lambda angle.
Definition: TrackBase.h:624
T min(T a, T b)
Definition: MathUtil.h:58
ParameterDescriptionBase * add(U const &iLabel, T const &value)
double dz() const
dz parameter (= dsz/cos(lambda)). This is the track z0 w.r.t. (0,0,0) only if the refPoint is close to...
Definition: TrackBase.h:648
double dzError() const
error on dz
Definition: TrackBase.h:865
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
const HitPattern & hitPattern() const
Access the hit pattern, indicating in which Tracker layers the track has hits.
Definition: TrackBase.h:485
int stripLayersWithMeasurement() const
Definition: HitPattern.h:1138
int numberOfLostHits(HitCategory category) const
Definition: HitPattern.h:990
int trackerLayersWithoutMeasurement(HitCategory category) const
Definition: HitPattern.cc:574
double lambdaError() const
error on lambda
Definition: TrackBase.h:829
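Structure Point: contains parameters of Gaussian fits to DMRs. (This cross-reference appears to be a Doxygen name collision: the Point used in this file is the type returned by getBestVertex and passed to Track::dxy/dz, not this structure.)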
Definition: DMRtrends.cc:55
T get() const
Definition: EventSetup.h:71
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:9
int numberOfValidPixelHits() const
Definition: HitPattern.h:916
const Point & position() const
position
Definition: BeamSpot.h:62
double dxy() const
dxy parameter. (This is the transverse impact parameter w.r.t. (0,0,0) ONLY if refPoint is close to...
Definition: TrackBase.h:630
T const * product() const
Definition: ESHandle.h:86