CMS 3D CMS Logo

TrackTfClassifier.cc
Go to the documentation of this file.
2 
9 
13 
// NOTE(review): this text is a Doxygen source-browser extraction. Each line is
// prefixed with its original line number, and several source lines are elided
// (17, 27, 102, 104-105, 123). The comments added here describe only what the
// visible lines show.
14 namespace {
// TfDnn: MVA functor used as the template argument of TrackMVAClassifier (the
// TrackTfClassifier alias at the end of this file). It scores each track with a
// TensorFlow graph taken from the EventSetup under the label `tfDnnLabel`. It
// evaluates the tracks in fixed-size batches of `batchSize` rows.
15  class TfDnn {
16  public:
// Constructor (its signature line is elided in this view). It reads
// `tfDnnLabel` and `batchSize` from the configuration and registers an
// EventSetup consumes for the graph product under that label. The session
// itself is fetched lazily in initEvent().
// NOTE(review): tfDnnToken_ is initialized from tfDnnLabel_. Members are
// initialized in declaration order, so this is correct only if tfDnnLabel_
// (line 122) is declared before tfDnnToken_. Presumably tfDnnToken_ is the
// elided member on line 123 — confirm.
18  : tfDnnLabel_(cfg.getParameter<std::string>("tfDnnLabel")),
19  tfDnnToken_(iC.esConsumes(edm::ESInputTag("", tfDnnLabel_))),
20  session_(nullptr),
21  bsize_(cfg.getParameter<int>("batchSize"))
22 
23  {}
24 
// Fixed identifier string for this functor.
// Presumably used by TrackMVAClassifier as a default module name — confirm.
25  static const char* name() { return "trackTfClassifierDefault"; }
26 
// Body of the parameter-description hook (its signature line is elided).
// It declares `tfDnnLabel` (default "trackSelectionTf") and `batchSize`
// (default 16).
28  desc.add<std::string>("tfDnnLabel", "trackSelectionTf");
29  desc.add<int>("batchSize", 16);
30  }
// Does nothing; no per-stream state is set up here.
31  void beginStream() {}
32 
// Fetches the TensorFlow session from the EventSetup product on the first call
// and caches it in session_.
// NOTE(review): the cached pointer is never refreshed afterwards. If the
// EventSetup product can change between IOVs, confirm that the cached session
// stays valid.
33  void initEvent(const edm::EventSetup& es) {
34  if (session_ == nullptr) {
35  session_ = es.getData(tfDnnToken_).getSession();
36  }
37  }
38 
// Scores every track in `tracks`.
// Returns one float per track, in the same order. Each value is 2*p - 1, where
// p is column 0 of the graph output "Identity". This maps [0,1] onto [-1,1],
// assuming the network emits a probability-like value in [0,1] — TODO confirm.
//
// Tracks are processed in batches of bsize_ rows. Each row holds 29 features
// in input1 and the track's originalAlgo() in input2.
//
// NOTE(review): nbatches = size_in / bsize_ and the outer loop runs
// nbatches + 1 times. When size_in is an exact multiple of bsize_ (including
// size_in == 0), the last pass fills no rows. It still calls tensorflow::run,
// on the previous batch's data or, when size_in == 0, on tensors that were
// never written. Those outputs are discarded, so the cost is only a wasted
// inference. Using ceil division ((size_in + bsize_ - 1) / bsize_) with
// `nb < nbatches` would avoid it.
// Rows past the last track in a partial batch likewise keep stale values;
// their outputs are ignored.
//
// NOTE(review): bsize_ comes straight from configuration. batchSize == 0 would
// divide by zero on the nbatches line, and negative values are not rejected
// anywhere visible.
39  std::vector<float> operator()(reco::TrackCollection const& tracks,
40  reco::BeamSpot const& beamSpot,
41  reco::VertexCollection const& vertices) const {
42  int size_in = (int)tracks.size();
43  int nbatches = size_in / bsize_;
44 
45  std::vector<float> output;
46  output.resize(size_in);
47 
// Per-batch input buffers, reused across batches: bsize_ x 29 features and
// bsize_ x 1 algorithm code.
48  tensorflow::Tensor input1(tensorflow::DT_FLOAT, {bsize_, 29});
49  tensorflow::Tensor input2(tensorflow::DT_FLOAT, {bsize_, 1});
50 
51  for (auto nb = 0; nb < nbatches + 1; nb++) {
52  for (auto nt = 0; nt < bsize_; nt++) {
53  int itrack = nt + bsize_ * nb;
54  if (itrack >= size_in)
55  continue;
56  const auto& trk = tracks[itrack];
57 
// Reference point for the dxy/dz features in columns 10-11.
58  const auto& bestVertex = getBestVertex(trk, vertices);
59 
// Feature columns 0-28 (presumably this order must match the column order the
// model was trained with — verify against the training code):
//   0 pt
//   1-4 inner momentum x, y, z, rho
//   5-8 outer momentum x, y, z, rho
//   9 ptError
//   10-11 dxy, dz w.r.t. the best vertex
//   12-13 dxy, dz w.r.t. the beam spot
//   14-15 dxyError, dzError
//   16 normalizedChi2
//   17-20 eta, phi, etaError, phiError
//   21-22 valid pixel / strip hits
//   23 ndof
//   24-25 lost tracker hits (missing inner / outer)
//   26-27 tracker layers totally off or bad (missing inner / outer)
//   28 tracker layers without measurement (track hits)
60  input1.matrix<float>()(nt, 0) = trk.pt();
61  input1.matrix<float>()(nt, 1) = trk.innerMomentum().x();
62  input1.matrix<float>()(nt, 2) = trk.innerMomentum().y();
63  input1.matrix<float>()(nt, 3) = trk.innerMomentum().z();
64  input1.matrix<float>()(nt, 4) = trk.innerMomentum().rho();
65  input1.matrix<float>()(nt, 5) = trk.outerMomentum().x();
66  input1.matrix<float>()(nt, 6) = trk.outerMomentum().y();
67  input1.matrix<float>()(nt, 7) = trk.outerMomentum().z();
68  input1.matrix<float>()(nt, 8) = trk.outerMomentum().rho();
69  input1.matrix<float>()(nt, 9) = trk.ptError();
70  input1.matrix<float>()(nt, 10) = trk.dxy(bestVertex);
71  input1.matrix<float>()(nt, 11) = trk.dz(bestVertex);
72  input1.matrix<float>()(nt, 12) = trk.dxy(beamSpot.position());
73  input1.matrix<float>()(nt, 13) = trk.dz(beamSpot.position());
74  input1.matrix<float>()(nt, 14) = trk.dxyError();
75  input1.matrix<float>()(nt, 15) = trk.dzError();
76  input1.matrix<float>()(nt, 16) = trk.normalizedChi2();
77  input1.matrix<float>()(nt, 17) = trk.eta();
78  input1.matrix<float>()(nt, 18) = trk.phi();
79  input1.matrix<float>()(nt, 19) = trk.etaError();
80  input1.matrix<float>()(nt, 20) = trk.phiError();
81  input1.matrix<float>()(nt, 21) = trk.hitPattern().numberOfValidPixelHits();
82  input1.matrix<float>()(nt, 22) = trk.hitPattern().numberOfValidStripHits();
83  input1.matrix<float>()(nt, 23) = trk.ndof();
84  input1.matrix<float>()(nt, 24) =
85  trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_INNER_HITS);
86  input1.matrix<float>()(nt, 25) =
87  trk.hitPattern().numberOfLostTrackerHits(reco::HitPattern::MISSING_OUTER_HITS);
88  input1.matrix<float>()(nt, 26) =
89  trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_INNER_HITS);
90  input1.matrix<float>()(nt, 27) =
91  trk.hitPattern().trackerLayersTotallyOffOrBad(reco::HitPattern::MISSING_OUTER_HITS);
92  input1.matrix<float>()(nt, 28) =
93  trk.hitPattern().trackerLayersWithoutMeasurement(reco::HitPattern::TRACK_HITS);
94 
95  //Original algo as its own input, it will enter the graph so that it gets one-hot encoded, as is the preferred
96  //format for categorical inputs, where the labels do not have any metric amongst them
97  input2.matrix<float>()(nt, 0) = trk.originalAlgo();
98  }
99 
100  //The names for the input tensors get locked when freezing the trained tensorflow model. The NamedTensors must
101  //match those names
// `inputs` is declared on an elided line (102) and filled on elided lines
// (104-105). Presumably it is a tensorflow::NamedTensorList binding input1 and
// input2 to the graph's input names — verify.
103  inputs.resize(2);
106  std::vector<tensorflow::Tensor> outputs;
107 
108  //evaluate the input
// session_ is declared as a pointer to const, but tensorflow::run takes a
// non-const Session*, hence the const_cast. Only the "Identity" output is
// fetched.
109  tensorflow::run(const_cast<tensorflow::Session*>(session_), inputs, {"Identity"}, &outputs);
110 
// Copy back only the rows that correspond to real tracks in this batch.
111  for (auto nt = 0; nt < bsize_; nt++) {
112  int itrack = nt + bsize_ * nb;
113  if (itrack >= size_in)
114  continue;
// Affine map 2*p - 1 of the first output column.
115  float out0 = 2.0 * outputs[0].matrix<float>()(nt, 0) - 1.0;
116  output[itrack] = out0;
117  }
118  }
119  return output;
120  }
121 
// EventSetup label of the graph product; also used to build the consumes token.
122  const std::string tfDnnLabel_;
// Session cached by initEvent(); null until the first event.
124  const tensorflow::Session* session_;
// Number of tracks per inference batch, taken from `batchSize`.
125  const int bsize_;
126  };
127 } // namespace
128 
// Specialization of TrackMVAClassifier for the TfDnn functor. Signature lines
// 130-131 and 133-134 are elided in this view. It runs the functor once over
// the whole track collection. It then stores each score, paired with `true`,
// into `mvas` in track order.
// NOTE(review): this writes through mvas[current++] without resizing, so
// `mvas` must already hold at least tracks.size() entries. Presumably the
// caller sizes it — confirm.
129 template <>
132  reco::BeamSpot const& beamSpot,
// The const reference binds to the returned temporary vector and extends its
// lifetime.
135  const auto& scores = mva(tracks, beamSpot, vertices);
136  size_t current = 0;
137 
138  for (auto score : scores) {
139  std::pair<float, bool> output(score, true);
140  mvas[current++] = output;
141  }
142 }
143 
// Concrete module type: the generic TrackMVAClassifier instantiated with the
// TfDnn functor. The unnamed namespace keeps the alias local to this
// translation unit.
144 namespace {
145  using TrackTfClassifier = TrackMVAClassifier<TfDnn>;
146 } // namespace
147 
150 
// Registers TrackTfClassifier through the framework's DEFINE_FWK_MODULE macro
// (MakerMacros.h). Presumably this makes it loadable by name from
// configuration.
151 DEFINE_FWK_MODULE(TrackTfClassifier);
ESGetTokenH3DDVariant esConsumes(std::string const &Record, edm::ConsumesCollector &)
Definition: DeDxTools.cc:283
std::vector< NamedTensor > NamedTensorList
Definition: TensorFlow.h:30
std::vector< std::pair< float, bool > > MVAPairCollection
std::vector< Track > TrackCollection
collection of Tracks
Definition: TrackFwd.h:14
std::vector< Vertex > VertexCollection
collection of Vertex objects
Definition: VertexFwd.h:9
#define input2
Definition: AMPTWrapper.h:159
void operator()(MVA const &mva, reco::TrackCollection const &tracks, reco::BeamSpot const &beamSpot, reco::VertexCollection const &vertices, TrackMVAClassifierBase::MVAPairCollection &mvas)
std::pair< std::string, Tensor > NamedTensor
Definition: TensorFlow.h:29
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:213
#define input1
Definition: AMPTWrapper.h:139
#define DEFINE_FWK_MODULE(type)
Definition: MakerMacros.h:16
bool getData(T &iHolder) const
Definition: EventSetup.h:122
int nt
Definition: AMPTWrapper.h:42
auto const & tracks
cannot be loose
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions)
HLT enums.
Point getBestVertex(reco::Track const &trk, reco::VertexCollection const &vertices, const size_t minNtracks=2)
Definition: getBestVertex.h:8