CMS 3D CMS Logo

TrackTfClassifier.cc
Go to the documentation of this file.
2 
8 #include "getBestVertex.h"
9 
13 
14 namespace {
// TfDnn: functor that scores a single reco::Track with a TensorFlow session.
// It is plugged into TrackMVAClassifier as the template argument (see the
// alias following this class).
// NOTE(review): this listing is a documentation extract; several original
// source lines are absent from it (Doxygen lines 25, 33, 71-72, 74, 76-77,
// 85, 87-88). The comments below describe only what is visible here.
15  class TfDnn {
16  public:
// Reads the string parameter "tfDnnLabel" from the configuration and starts
// with no session (session_ == nullptr); the session is fetched in initEvent().
17  TfDnn(const edm::ParameterSet& cfg)
18  : tfDnnLabel_(cfg.getParameter<std::string>("tfDnnLabel")),
19  session_(nullptr)
20 
21  {}
22 
// Returns the fixed string "TrackTfClassifier".
23  static const char* name() { return "TrackTfClassifier"; }
24 
// Declares the "tfDnnLabel" string parameter with default "trackSelectionTf".
// NOTE(review): the enclosing function header (Doxygen line 25) is missing from
// this extract; presumably it is the static fillDescriptions hook taking
// `desc` - confirm against the real file.
26  desc.add<std::string>("tfDnnLabel", "trackSelectionTf");
27  }
28 
// Intentionally empty.
29  void beginStream() {}
30 
// Fetches the product labelled tfDnnLabel_ from TfGraphRecord and stores its
// session pointer. The lookup is skipped once session_ is non-null, so it is
// only performed while no session has been stored yet.
// NOTE(review): the declaration of tfDnnHandle (Doxygen line 33) is missing
// from this extract.
31  void initEvent(const edm::EventSetup& es) {
32  if (session_ == nullptr) {
34  es.get<TfGraphRecord>().get(tfDnnLabel_, tfDnnHandle);
35  session_ = tfDnnHandle.product()->getSession();
36  }
37  }
38 
// Evaluates the network for one track.
//   trk      - track whose quantities fill the input tensors
//   beamSpot - its position() is the reference point for features 12 and 13
//   vertices - passed to getBestVertex() to pick the reference point for
//              features 10 and 11
// Returns 2 * (element (0, 0) of the first output tensor) - 1.
39  float operator()(reco::Track const& trk,
40  reco::BeamSpot const& beamSpot,
41  reco::VertexCollection const& vertices) const {
42  const auto& bestVertex = getBestVertex(trk, vertices);
43 
// input1 is a 1 x 29 float tensor of track features; input2 is a 1 x 1 float
// tensor holding the original algorithm value.
44  tensorflow::Tensor input1(tensorflow::DT_FLOAT, {1, 29});
45  tensorflow::Tensor input2(tensorflow::DT_FLOAT, {1, 1});
46 
// Columns 0-8: pt, then x/y/z/rho of the inner and of the outer momentum.
47  input1.matrix<float>()(0, 0) = trk.pt();
48  input1.matrix<float>()(0, 1) = trk.innerMomentum().x();
49  input1.matrix<float>()(0, 2) = trk.innerMomentum().y();
50  input1.matrix<float>()(0, 3) = trk.innerMomentum().z();
51  input1.matrix<float>()(0, 4) = trk.innerMomentum().rho();
52  input1.matrix<float>()(0, 5) = trk.outerMomentum().x();
53  input1.matrix<float>()(0, 6) = trk.outerMomentum().y();
54  input1.matrix<float>()(0, 7) = trk.outerMomentum().z();
55  input1.matrix<float>()(0, 8) = trk.outerMomentum().rho();
// Columns 9-15: pt error, dxy/dz relative to the best vertex and to the beam
// spot position, and the dxy/dz errors.
56  input1.matrix<float>()(0, 9) = trk.ptError();
57  input1.matrix<float>()(0, 10) = trk.dxy(bestVertex);
58  input1.matrix<float>()(0, 11) = trk.dz(bestVertex);
59  input1.matrix<float>()(0, 12) = trk.dxy(beamSpot.position());
60  input1.matrix<float>()(0, 13) = trk.dz(beamSpot.position());
61  input1.matrix<float>()(0, 14) = trk.dxyError();
62  input1.matrix<float>()(0, 15) = trk.dzError();
// Columns 16-23: normalized chi2, eta/phi and their errors, valid pixel and
// strip hit counts, and ndof.
63  input1.matrix<float>()(0, 16) = trk.normalizedChi2();
64  input1.matrix<float>()(0, 17) = trk.eta();
65  input1.matrix<float>()(0, 18) = trk.phi();
66  input1.matrix<float>()(0, 19) = trk.etaError();
67  input1.matrix<float>()(0, 20) = trk.phiError();
68  input1.matrix<float>()(0, 21) = trk.hitPattern().numberOfValidPixelHits();
69  input1.matrix<float>()(0, 22) = trk.hitPattern().numberOfValidStripHits();
70  input1.matrix<float>()(0, 23) = trk.ndof();
// NOTE(review): the assignments for columns 24, 25 and 28, and the right-hand
// sides for columns 26 and 27, are missing from this extract (Doxygen lines
// 71-72, 74, 76-77). Presumably they use the HitPattern lost-hit / layer
// counters listed in the page index - confirm against the real file.
73  input1.matrix<float>()(0, 26) =
75  input1.matrix<float>()(0, 27) =
78 
79  //Original algo as its own input, it will enter the graph so that it gets one-hot encoded, as is the preferred
80  //format for categorical inputs, where the labels do not have any metric amongst them
81  input2.matrix<float>()(0, 0) = trk.originalAlgo();
82 
83  //The names for the input tensors get locked when freezing the trained tensorflow model. The NamedTensors must
84  //match those names
// NOTE(review): the declaration of `inputs` and the two element assignments
// (Doxygen lines 85, 87-88) are missing from this extract; only the resize to
// two entries is visible.
86  inputs.resize(2);
89  std::vector<tensorflow::Tensor> outputs;
90 
91  //evaluate the input
// Requests the single output named "Identity". The const_cast removes the
// const from session_ (declared const tensorflow::Session*) so it can be
// passed to tensorflow::run.
92  tensorflow::run(const_cast<tensorflow::Session*>(session_), inputs, {"Identity"}, &outputs);
93  //scale output to be [-1, 1] due to convention
94  float output = 2.0 * outputs[0].matrix<float>()(0, 0) - 1.0;
95  return output;
96  }
97 
// Label used for the TfGraphRecord lookup in initEvent().
98  const std::string tfDnnLabel_;
// Session pointer stored by initEvent(); nullptr until the first successful lookup.
99  const tensorflow::Session* session_;
100  };
101 
// Alias for the TrackMVAClassifier template instantiated with the TfDnn functor.
102  using TrackTfClassifier = TrackMVAClassifier<TfDnn>;
103 } // namespace
106 
// Passes the TrackTfClassifier alias to the DEFINE_FWK_MODULE macro.
// NOTE(review): presumably this registers the type as a framework module - confirm in MakerMacros.h.
107 DEFINE_FWK_MODULE(TrackTfClassifier);
reco::Track::outerMomentum
const math::XYZVector & outerMomentum() const
momentum vector at the outermost hit position
Definition: Track.h:65
TfGraphDefWrapper.h
edm::ESHandle::product
T const * product() const
Definition: ESHandle.h:86
reco::HitPattern::MISSING_OUTER_HITS
Definition: HitPattern.h:155
TfGraphDefWrapper::getSession
const tensorflow::Session * getSession() const
Definition: TfGraphDefWrapper.cc:5
pwdgSkimBPark_cfi.beamSpot
beamSpot
Definition: pwdgSkimBPark_cfi.py:5
reco::TrackBase::ptError
double ptError() const
error on Pt (set to 1000 TeV if charge==0 for safety)
Definition: TrackBase.h:754
TensorFlow.h
reco::HitPattern::trackerLayersWithoutMeasurement
int trackerLayersWithoutMeasurement(HitCategory category) const
Definition: HitPattern.cc:529
reco::TrackBase::etaError
double etaError() const
error on eta
Definition: TrackBase.h:763
convertSQLitetoXML_cfg.output
output
Definition: convertSQLitetoXML_cfg.py:72
tensorflow::NamedTensor
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:29
reco::VertexCollection
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
PatBasicFWLiteJetAnalyzer_Selector_cfg.outputs
outputs
Definition: PatBasicFWLiteJetAnalyzer_Selector_cfg.py:48
edm::ParameterSetDescription
Definition: ParameterSetDescription.h:52
getBestVertex.h
reco::TrackBase::originalAlgo
TrackAlgorithm originalAlgo() const
Definition: TrackBase.h:548
reco::TrackBase::ndof
double ndof() const
number of degrees of freedom of the fit
Definition: TrackBase.h:590
reco::HitPattern::trackerLayersTotallyOffOrBad
int trackerLayersTotallyOffOrBad(HitCategory category=TRACK_HITS) const
Definition: HitPattern.h:1010
reco::TrackBase::dxyError
double dxyError() const
error on dxy
Definition: TrackBase.h:769
AlignmentTracksFromVertexSelector_cfi.vertices
vertices
Definition: AlignmentTracksFromVertexSelector_cfi.py:5
TfGraphRecord
Definition: TfGraphRecord.h:20
reco::TrackBase::pt
double pt() const
track transverse momentum
Definition: TrackBase.h:637
TrackMVAClassifier.h
MakerMacros.h
reco::Track::innerMomentum
const math::XYZVector & innerMomentum() const
momentum vector at the innermost hit position
Definition: Track.h:59
TrackMVAClassifier
Definition: TrackMVAClassifier.h:95
Track.h
edm::EventSetup::get
T get() const
Definition: EventSetup.h:87
DEFINE_FWK_MODULE
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
fillDescriptions
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
reco::BeamSpot
Definition: BeamSpot.h:21
reco::Track
Definition: Track.h:27
edm::ESHandle
Definition: DTSurvey.h:22
reco::TrackBase::dz
double dz() const
dz parameter (= dsz/cos(lambda)). This is the track z0 w.r.t (0,0,0) only if the refPoint is close to...
Definition: TrackBase.h:622
reco::TrackBase::phi
double phi() const
azimuthal angle of momentum vector
Definition: TrackBase.h:649
edm::ParameterSet
Definition: ParameterSet.h:47
reco::TrackBase::eta
double eta() const
pseudorapidity of momentum vector
Definition: TrackBase.h:652
tensorflow::NamedTensorList
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:30
reco::TrackBase::dzError
double dzError() const
error on dz
Definition: TrackBase.h:778
ModuleDef.h
PixelMapPlotter.inputs
inputs
Definition: PixelMapPlotter.py:490
edm::EventSetup
Definition: EventSetup.h:58
reco::TrackBase::normalizedChi2
double normalizedChi2() const
chi-squared divided by n.d.o.f. (or chi-squared * 1e6 if n.d.o.f. is zero)
Definition: TrackBase.h:593
get
#define get
input2
#define input2
Definition: AMPTWrapper.h:159
AlCaHLTBitMon_QueryRunRegistry.string
string string
Definition: AlCaHLTBitMon_QueryRunRegistry.py:256
reco::HitPattern::TRACK_HITS
Definition: HitPattern.h:155
looper.cfg
cfg
Definition: looper.py:296
reco::TrackBase::hitPattern
const HitPattern & hitPattern() const
Access the hit pattern, indicating in which Tracker layers the track has hits.
Definition: TrackBase.h:504
TfGraphRecord.h
submitPVResolutionJobs.desc
string desc
Definition: submitPVResolutionJobs.py:251
std
Definition: JetResolutionObject.h:76
reco::HitPattern::numberOfValidStripHits
int numberOfValidStripHits() const
Definition: HitPattern.h:830
Vertex.h
reco::HitPattern::MISSING_INNER_HITS
Definition: HitPattern.h:155
tensorflow::run
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:213
Skims_PA_cff.name
name
Definition: Skims_PA_cff.py:17
EventSetup.h
reco::HitPattern::numberOfLostTrackerHits
int numberOfLostTrackerHits(HitCategory category) const
Definition: HitPattern.h:880
reco::HitPattern::numberOfValidPixelHits
int numberOfValidPixelHits() const
Definition: HitPattern.h:818
reco::TrackBase::phiError
double phiError() const
error on phi
Definition: TrackBase.h:766
ConsumesCollector.h
input1
#define input1
Definition: AMPTWrapper.h:139
EDProducer.h
reco::TrackBase::dxy
double dxy() const
dxy parameter. (This is the transverse impact parameter w.r.t. (0,0,0) ONLY if refPoint is close t...
Definition: TrackBase.h:608
getBestVertex
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:8