CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
TrackTfClassifier.cc
Go to the documentation of this file.
2 
9 
13 
// File-local helpers. TfDnn is the per-track evaluator that TrackMVAClassifier is
// instantiated with at the end of this namespace (TrackTfClassifier =
// TrackMVAClassifier<TfDnn>). It scores each track with a frozen TensorFlow graph
// whose session is taken from an EventSetup product.
// NOTE(review): this is a Doxygen listing; each code line keeps its original listing
// line number as a prefix, and listing lines 17, 26, 69-70, 72, 74-75, 83 and 97
// are absent from this page.
14 namespace {
15  class TfDnn {
16  public:
  // Constructor. NOTE(review): its signature (listing line 17) is missing here; from the
  // initializer list it receives a parameter set `cfg` and a consumes collector `iC`.
  // It reads the "tfDnnLabel" string parameter and registers an EventSetup consumes
  // for the product with that label. session_ stays null until initEvent() runs.
  // tfDnnToken_ is built from tfDnnLabel_, so tfDnnLabel_ must remain declared before
  // tfDnnToken_ (members are initialized in declaration order).
18  : tfDnnLabel_(cfg.getParameter<std::string>("tfDnnLabel")),
19  tfDnnToken_(iC.esConsumes(edm::ESInputTag("", tfDnnLabel_))),
20  session_(nullptr)
21 
22  {}
23 
  // Returns the fixed name string "trackTfClassifierDefault".
  // Presumably TrackMVAClassifier uses it as the default module label; confirm.
24  static const char* name() { return "trackTfClassifierDefault"; }
25 
  // Parameter-description hook. NOTE(review): the function header (listing line 26)
  // is missing from this page. The body declares "tfDnnLabel" with the default
  // value "trackSelectionTf".
27  desc.add<std::string>("tfDnnLabel", "trackSelectionTf");
28  }
  // No per-stream setup is needed.
29  void beginStream() {}
30 
  // Fetches the TensorFlow session from the EventSetup product, but only on the first
  // call; every later call keeps the cached pointer.
  // NOTE(review): if the EventSetup product can be replaced (e.g. on an IOV change),
  // this cached raw pointer could go stale. Confirm, or re-fetch it on every event.
31  void initEvent(const edm::EventSetup& es) {
32  if (session_ == nullptr) {
33  session_ = es.getData(tfDnnToken_).getSession();
34  }
35  }
36 
  // Scores one track with the TensorFlow graph.
  // - Input "x" is a 1x29 float tensor of per-track features.
  // - Input "y" is a 1x1 tensor holding trk.originalAlgo().
  // - Returns 2 * (graph output "Identity") - 1.
  // NOTE(review): feature slots 24, 25 and 28, and the right-hand sides of slots 26 and
  // 27, live on listing lines 69-75, which this page omits. The tooltips further down
  // mention numberOfLostTrackerHits, trackerLayersTotallyOffOrBad and
  // trackerLayersWithoutMeasurement, presumably used there; confirm.
37  float operator()(reco::Track const& trk,
38  reco::BeamSpot const& beamSpot,
39  reco::VertexCollection const& vertices) const {
    // getBestVertex (getBestVertex.h) returns a Point; it is the reference point for
    // the dxy/dz features in slots 10-11.
40  const auto& bestVertex = getBestVertex(trk, vertices);
41 
    // Shapes are fixed here: one row (this track) by 29 features, plus a 1x1 tensor.
42  tensorflow::Tensor input1(tensorflow::DT_FLOAT, {1, 29});
43  tensorflow::Tensor input2(tensorflow::DT_FLOAT, {1, 1});
44 
    // Slots 0-8: pT, then x/y/z/rho of the inner and outer momentum vectors.
45  input1.matrix<float>()(0, 0) = trk.pt();
46  input1.matrix<float>()(0, 1) = trk.innerMomentum().x();
47  input1.matrix<float>()(0, 2) = trk.innerMomentum().y();
48  input1.matrix<float>()(0, 3) = trk.innerMomentum().z();
49  input1.matrix<float>()(0, 4) = trk.innerMomentum().rho();
50  input1.matrix<float>()(0, 5) = trk.outerMomentum().x();
51  input1.matrix<float>()(0, 6) = trk.outerMomentum().y();
52  input1.matrix<float>()(0, 7) = trk.outerMomentum().z();
53  input1.matrix<float>()(0, 8) = trk.outerMomentum().rho();
54  input1.matrix<float>()(0, 9) = trk.ptError();
    // Slots 10-13: dxy/dz w.r.t. the best vertex, then w.r.t. the beam-spot position.
55  input1.matrix<float>()(0, 10) = trk.dxy(bestVertex);
56  input1.matrix<float>()(0, 11) = trk.dz(bestVertex);
57  input1.matrix<float>()(0, 12) = trk.dxy(beamSpot.position());
58  input1.matrix<float>()(0, 13) = trk.dz(beamSpot.position());
    // Slots 14-23: errors, fit quality, direction, and hit counts.
59  input1.matrix<float>()(0, 14) = trk.dxyError();
60  input1.matrix<float>()(0, 15) = trk.dzError();
61  input1.matrix<float>()(0, 16) = trk.normalizedChi2();
62  input1.matrix<float>()(0, 17) = trk.eta();
63  input1.matrix<float>()(0, 18) = trk.phi();
64  input1.matrix<float>()(0, 19) = trk.etaError();
65  input1.matrix<float>()(0, 20) = trk.phiError();
66  input1.matrix<float>()(0, 21) = trk.hitPattern().numberOfValidPixelHits();
67  input1.matrix<float>()(0, 22) = trk.hitPattern().numberOfValidStripHits();
68  input1.matrix<float>()(0, 23) = trk.ndof();
    // NOTE(review): the right-hand sides of the next two assignments (listing lines 72
    // and 74) are not shown on this page.
71  input1.matrix<float>()(0, 26) =
73  input1.matrix<float>()(0, 27) =
76 
77  //Original algo as its own input, it will enter the graph so that it gets one-hot encoded, as is the preferred
78  //format for categorical inputs, where the labels do not have any metric amongst them
79  input2.matrix<float>()(0, 0) = trk.originalAlgo();
80 
81  //The names for the input tensors get locked when freezing the trained tensorflow model. The NamedTensors must
82  //match those names
    // NOTE(review): the declaration of `inputs` (listing line 83) is missing here;
    // presumably a tensorflow::NamedTensorList. Confirm.
84  inputs.resize(2);
85  inputs[0] = tensorflow::NamedTensor("x", input1);
86  inputs[1] = tensorflow::NamedTensor("y", input2);
87  std::vector<tensorflow::Tensor> outputs;
88 
89  //evaluate the input
    // tensorflow::run takes a non-const Session* (see its signature in TensorFlow.cc),
    // while session_ is stored as const Session*, hence the const_cast. "Identity" is
    // the single output node requested from the graph.
90  tensorflow::run(const_cast<tensorflow::Session*>(session_), inputs, {"Identity"}, &outputs);
91  //scale output to be [-1, 1] due to convention
    // The mapping to [-1, 1] assumes the graph output lies in [0, 1]; confirm.
    // 2.0 and 1.0 are double literals, so the arithmetic is done in double and then
    // narrowed to float on assignment.
    // NOTE(review): outputs[0] is read without checking that run() filled `outputs`.
92  float output = 2.0 * outputs[0].matrix<float>()(0, 0) - 1.0;
93  return output;
94  }
95 
  // Label of the EventSetup product that provides the session (config "tfDnnLabel").
96  const std::string tfDnnLabel_;
  // NOTE(review): listing line 97 (presumably the tfDnnToken_ declaration) is missing.
  // Session handle set in initEvent(); nothing in this class deletes it.
98  const tensorflow::Session* session_;
99  };
100 
  // The framework module type: TrackMVAClassifier instantiated with TfDnn.
101  using TrackTfClassifier = TrackMVAClassifier<TfDnn>;
102 } // namespace
105 
// Framework module registration for TrackTfClassifier (DEFINE_FWK_MODULE is the macro
// from MakerMacros.h).
106 DEFINE_FWK_MODULE(TrackTfClassifier);
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:30
tuple cfg
Definition: looper.py:296
double normalizedChi2() const
chi-squared divided by n.d.o.f. (or chi-squared * 1e6 if n.d.o.f. is zero)
Definition: TrackBase.h:593
double dxyError() const
error on dxy
Definition: TrackBase.h:769
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
double etaError() const
error on eta
Definition: TrackBase.h:763
int numberOfValidStripHits() const
Definition: HitPattern.h:843
double phi() const
azimuthal angle of momentum vector
Definition: TrackBase.h:649
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
#define input2
Definition: AMPTWrapper.h:159
int numberOfLostTrackerHits(HitCategory category) const
Definition: HitPattern.h:893
bool getData(T &iHolder) const
Definition: EventSetup.h:122
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:29
double eta() const
pseudorapidity of momentum vector
Definition: TrackBase.h:652
double ndof() const
number of degrees of freedom of the fit
Definition: TrackBase.h:590
double pt() const
track transverse momentum
Definition: TrackBase.h:637
double ptError() const
error on Pt (set to 1000 TeV if charge==0 for safety)
Definition: TrackBase.h:754
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:213
double phiError() const
error on phi
Definition: TrackBase.h:766
#define input1
Definition: AMPTWrapper.h:139
ParameterDescriptionBase * add(U const &iLabel, T const &value)
int trackerLayersTotallyOffOrBad(HitCategory category=TRACK_HITS) const
Definition: HitPattern.h:1023
double dz() const
dz parameter (= dsz/cos(lambda)). This is the track z0 w.r.t. (0,0,0) only if the refPoint is close to...
Definition: TrackBase.h:622
double dzError() const
error on dz
Definition: TrackBase.h:778
TrackAlgorithm originalAlgo() const
Definition: TrackBase.h:548
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
const math::XYZVector & outerMomentum() const
momentum vector at the outermost hit position
Definition: Track.h:65
const HitPattern & hitPattern() const
Access the hit pattern, indicating in which Tracker layers the track has hits.
Definition: TrackBase.h:504
int trackerLayersWithoutMeasurement(HitCategory category) const
Definition: HitPattern.cc:553
const math::XYZVector & innerMomentum() const
momentum vector at the innermost hit position
Definition: Track.h:59
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:8
int numberOfValidPixelHits() const
Definition: HitPattern.h:831
const Point & position() const
position
Definition: BeamSpot.h:59
double dxy() const
dxy parameter. (This is the transverse impact parameter w.r.t. (0,0,0) ONLY if refPoint is close t...
Definition: TrackBase.h:608
ESGetTokenH3DDVariant esConsumes(std::string const &Reccord, edm::ConsumesCollector &)
Definition: DeDxTools.cc:283