CMS 3D CMS Logo

TrackLwtnnClassifier.cc
Go to the documentation of this file.
2 
5 
8 
10 
11 //from lwtnn
13 #include "lwtnn/LightweightNeuralNetwork.hh"
14 
15 namespace {
16  struct lwtnn {
18  : lwtnnLabel_(cfg.getParameter<std::string>("lwtnnLabel")),
19  lwtnnToken_(iC.esConsumes(edm::ESInputTag("", lwtnnLabel_))) {}
20 
21  static const char* name() { return "TrackLwtnnClassifier"; }
22 
24  desc.add<std::string>("lwtnnLabel", "trackSelectionLwtnn");
25  }
26 
27  void beginStream() {}
28  void initEvent(const edm::EventSetup& es) { neuralNetwork_ = &es.getData(lwtnnToken_); }
29 
30  std::pair<float, bool> operator()(reco::Track const& trk,
31  reco::BeamSpot const& beamSpot,
33  lwt::ValueMap& inputs) const {
34  // lwt::ValueMap is a typedef for std::map<std::string, double>
35  //
36  // It is cached per event to avoid constructing the map for each
37  // track while keeping the operator() interface const.
38 
39  Point bestVertex = getBestVertex(trk, vertices);
40 
41  inputs["trk_pt"] = trk.pt();
42  inputs["trk_eta"] = trk.eta();
43  inputs["trk_lambda"] = trk.lambda();
44  inputs["trk_dxy"] = trk.dxy(beamSpot.position()); // Training done without taking absolute value
45  inputs["trk_dz"] = trk.dz(beamSpot.position()); // Training done without taking absolute value
46  inputs["trk_dxyClosestPV"] = trk.dxy(bestVertex); // Training done without taking absolute value
47  // Training done without taking absolute value
48  inputs["trk_dzClosestPVNorm"] = std::max(-0.2, std::min(trk.dz(bestVertex), 0.2));
49  inputs["trk_ptErr"] = trk.ptError();
50  inputs["trk_etaErr"] = trk.etaError();
51  inputs["trk_lambdaErr"] = trk.lambdaError();
52  inputs["trk_dxyErr"] = trk.dxyError();
53  inputs["trk_dzErr"] = trk.dzError();
54  inputs["trk_nChi2"] = trk.normalizedChi2();
55  inputs["trk_ndof"] = trk.ndof();
57  inputs["trk_nPixel"] = trk.hitPattern().numberOfValidPixelHits();
58  inputs["trk_nStrip"] = trk.hitPattern().numberOfValidStripHits();
59  inputs["trk_nPixelLay"] = trk.hitPattern().pixelLayersWithMeasurement();
60  inputs["trk_nStripLay"] = trk.hitPattern().stripLayersWithMeasurement();
64  inputs["trk_algo"] = trk.algo();
65 
66  auto out = neuralNetwork_->compute(inputs);
67  // there should be only one output
68  if (out.size() != 1)
69  throw cms::Exception("LogicError") << "Expecting exactly one output from NN, got " << out.size();
70 
71  float output = 2.0 * out.begin()->second - 1.0;
72 
73  //Check if the network is known to be unreliable in that part of phase space. Hard cut values
74  //correspond to rare tracks known to be difficult for the Deep Neural Network classifier
75 
76  bool isReliable = true;
77  //T1qqqq
78  if (std::abs(inputs["trk_dxy"]) >= 0.1 && inputs["trk_etaErr"] < 0.003 && inputs["trk_dxyErr"] < 0.03 &&
79  inputs["trk_ndof"] > 3) {
80  isReliable = false;
81  }
82  //T5qqqqLL
83  if ((inputs["trk_pt"] > 100.0) && (inputs["trk_nChi2"] < 4.0) && (inputs["trk_etaErr"] < 0.001)) {
84  isReliable = false;
85  }
86 
87  std::pair<float, bool> return_(output, isReliable);
88  return return_;
89  }
90 
91  std::string lwtnnLabel_;
93  const lwt::LightweightNeuralNetwork* neuralNetwork_;
94  };
95 
96  using TrackLwtnnClassifier = TrackMVAClassifier<lwtnn, lwt::ValueMap>;
97 } // namespace
98 
101 
102 DEFINE_FWK_MODULE(TrackLwtnnClassifier);
int numberOfValidPixelHits() const
Definition: HitPattern.h:831
ESGetTokenH3DDVariant esConsumes(std::string const &Record, edm::ConsumesCollector &)
Definition: DeDxTools.cc:283
int numberOfLostHits(HitCategory category) const
Definition: HitPattern.h:891
double lambda() const
Lambda angle.
Definition: TrackBase.h:605
double ptError() const
error on Pt (set to 1000 TeV if charge==0 for safety)
Definition: TrackBase.h:754
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
int stripLayersWithMeasurement() const
Definition: HitPattern.h:1005
double lambdaError() const
error on lambda
Definition: TrackBase.h:760
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
int numberOfValidStripHits() const
Definition: HitPattern.h:843
double pt() const
track transverse momentum
Definition: TrackBase.h:637
double ndof() const
number of degrees of freedom of the fit
Definition: TrackBase.h:590
int numberOfValidStripLayersWithMonoAndStereo(uint16_t stripdet, uint16_t layer) const
Definition: HitPattern.cc:369
double dz() const
dz parameter (= dsz/cos(lambda)). This is the track z0 w.r.t (0,0,0) only if the refPoint is close to...
Definition: TrackBase.h:622
double dxyError() const
error on dxy
Definition: TrackBase.h:769
double dzError() const
error on dz
Definition: TrackBase.h:778
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
bool getData(T &iHolder) const
Definition: EventSetup.h:122
double eta() const
pseudorapidity of momentum vector
Definition: TrackBase.h:652
TrackAlgorithm algo() const
Definition: TrackBase.h:547
int trackerLayersWithoutMeasurement(HitCategory category) const
Definition: HitPattern.cc:553
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
const HitPattern & hitPattern() const
Access the hit pattern, indicating in which Tracker layers the track has hits.
Definition: TrackBase.h:504
double normalizedChi2() const
chi-squared divided by n.d.o.f. (or chi-squared * 1e6 if n.d.o.f. is zero)
Definition: TrackBase.h:593
HLT enums.
Structure Point Contains parameters of Gaussian fits to DMRs.
Definition: DMRtrends.cc:57
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:8
int pixelLayersWithMeasurement() const
Definition: HitPattern.cc:513
double etaError() const
error on eta
Definition: TrackBase.h:763
double dxy() const
dxy parameter. (This is the transverse impact parameter w.r.t. (0,0,0) ONLY if refPoint is close t...
Definition: TrackBase.h:608