CMS 3D CMS Logo

PtAssignmentEngineDxy.cc
Go to the documentation of this file.
2 
3 #include <cassert>
4 #include <iostream>
5 #include <sstream>
6 
7 #include "helper.h" // assert_no_abort
8 
9 PtAssignmentEngineDxy::PtAssignmentEngineDxy() : graphDefDxy_(nullptr), sessionDxy_(nullptr) {}
10 
12  if (sessionDxy_ != nullptr) {
14  }
15  delete graphDefDxy_;
16 }
17 
18 void PtAssignmentEngineDxy::configure(int verbose, const std::string pbFileNameDxy) {
19  verbose_ = verbose;
20 
21  pbFileNameDxy_ = pbFileNameDxy;
22  std::string pbFilePathDxy_ = "L1Trigger/L1TMuon/data/emtf_luts/" + pbFileNameDxy_;
23 
24  inputNameDxy_ = "input1";
25  outputNamesDxy_ = {"Identity"};
26 
27  if (graphDefDxy_ == nullptr) {
29  }
30  emtf_assert(graphDefDxy_ != nullptr);
31 
32  if (sessionDxy_ == nullptr) {
34  }
35 
36  emtf_assert(sessionDxy_ != nullptr);
37 }
38 
41  return instance;
42 }
43 
45  emtf::Feature& feature,
46  emtf::Prediction& prediction) const {
47  // This is called for each track instead of for entire track collection as was done in Phase-2 implementation
48  preprocessing_dxy(track, feature);
49  call_tensorflow_dxy(feature, prediction);
50  return;
51 }
52 
54  // Mimic Phase-1 EMTF input calculations
55  // 6 delta Phis: S1-S2, S1-S3, S1-S4, S2-S3, S2-S4, S3-S4
56  // 6 delta Thetas: S1-S2, S1-S3, S1-S4, S2-S3, S2-S4, S3-S4
57  // 6 delta Phi signs: S1-S2, S1-S3, S1-S4, S2-S3, S2-S4, S3-S4
58  // 6 delta Theta signs: S1-S2, S1-S3, S1-S4, S2-S3, S2-S4, S3-S4
59  // 1 track Theta taken from stub coordinate in ME2, ME3, ME4 (in this priority)
60  // 4 CSC pattern values (Run 2 convention): S1, S2, S3, S4
61  // Total: 29 variables
62  std::array<float, 6> x_dphi;
63  std::array<float, 6> x_dphi_sign;
64  std::array<float, 6> x_dtheta;
65  std::array<float, 6> x_dtheta_sign;
66  std::array<float, 1> x_trk_theta;
67  std::array<float, 4> x_csc_pattern;
68 
69  // Initialize x_csc_pattern to zeros
70  x_csc_pattern.fill(0);
71 
72  EMTFPtLUT data = track.PtLUT();
73 
74  const int invalid_dtheta = 127;
75  const int invalid_dphi = 8191;
76 
77  // // Which stations have hits
78  bool st1 = (track.Mode() >= 8);
79  bool st2 = ((track.Mode() % 8) >= 4);
80  bool st3 = ((track.Mode() % 4) >= 2);
81  bool st4 = ((track.Mode() % 2) == 1);
82 
83  // Get valid pattern values
84  if (st1)
85  x_csc_pattern[0] = data.cpattern[0];
86  if (st2)
87  x_csc_pattern[1] = data.cpattern[1];
88  if (st3)
89  x_csc_pattern[2] = data.cpattern[2];
90  if (st4)
91  x_csc_pattern[3] = data.cpattern[3];
92 
93  for (int i = 0; i < 6; ++i) { // There are 6 deltas between 4 stations.
94  // Calculate delta phi
95  x_dphi[i] = (data.delta_ph[i] != invalid_dphi) ? data.delta_ph[i] : 0;
96 
97  // Calculate delta theta
98  x_dtheta[i] = (data.delta_th[i] != invalid_dtheta) ? data.delta_th[i] : 0;
99 
100  // Get delta phi and theta signs
101  x_dphi_sign[i] = data.sign_ph[i];
102  x_dtheta_sign[i] = data.sign_th[i];
103  }
104 
105  // Set dPhi and dTheta values to 0 if there was no hit in the station
106  if (!st1) {
107  x_dphi[0] = 0;
108  x_dphi[1] = 0;
109  x_dphi[2] = 0;
110 
111  x_dtheta[0] = 0;
112  x_dtheta[1] = 0;
113  x_dtheta[2] = 0;
114  }
115  if (!st2) {
116  x_dphi[0] = 0;
117  x_dphi[3] = 0;
118  x_dphi[4] = 0;
119 
120  x_dtheta[0] = 0;
121  x_dtheta[3] = 0;
122  x_dtheta[4] = 0;
123  }
124  if (!st3) {
125  x_dphi[1] = 0;
126  x_dphi[3] = 0;
127  x_dphi[5] = 0;
128 
129  x_dtheta[1] = 0;
130  x_dtheta[3] = 0;
131  x_dtheta[5] = 0;
132  }
133  if (!st4) {
134  x_dphi[2] = 0;
135  x_dphi[4] = 0;
136  x_dphi[5] = 0;
137 
138  x_dtheta[2] = 0;
139  x_dtheta[4] = 0;
140  x_dtheta[5] = 0;
141  }
142 
143  x_trk_theta[0] = track.Theta_fp();
144 
145  // Set NN inputs
146  feature = {{x_dphi[0], x_dphi[1], x_dphi[2], x_dphi[3], x_dphi[4],
147  x_dphi[5], x_dphi_sign[0], x_dphi_sign[1], x_dphi_sign[2], x_dphi_sign[3],
148  x_dphi_sign[4], x_dphi_sign[5], x_dtheta[0], x_dtheta[1], x_dtheta[2],
149  x_dtheta[3], x_dtheta[4], x_dtheta[5], x_dtheta_sign[0], x_dtheta_sign[1],
150  x_dtheta_sign[2], x_dtheta_sign[3], x_dtheta_sign[4], x_dtheta_sign[5], x_csc_pattern[0],
151  x_csc_pattern[1], x_csc_pattern[2], x_csc_pattern[3], x_trk_theta[0]}};
152 
153  return;
154 }
155 
157  tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, emtf::NUM_FEATURES});
158  std::vector<tensorflow::Tensor> outputs;
159  emtf_assert(feature.size() == emtf::NUM_FEATURES);
160 
161  float* d = input.flat<float>().data();
162  std::copy(feature.begin(), feature.end(), d);
164  emtf_assert(outputs.size() == 1);
165  emtf_assert(prediction.size() == emtf::NUM_PREDICTIONS);
166 
167  prediction.at(0) = outputs[0].matrix<float>()(0, 0);
168  prediction.at(1) = outputs[0].matrix<float>()(0, 1);
169 
170  return;
171 }
static PFTauRenderPlugin instance
bool verbose
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:129
std::vector< std::string > outputNamesDxy_
const PtAssignmentEngineAux2017 & aux() const
virtual void calculate_pt_dxy(const EMTFTrack &track, emtf::Feature &feature, emtf::Prediction &prediction) const
constexpr int NUM_PREDICTIONS
Definition: Common.h:69
tensorflow::Session * sessionDxy_
void configure(int verbose, const std::string pbFileNameDxy)
static std::string const input
Definition: EdmProvDump.cc:50
virtual void call_tensorflow_dxy(const emtf::Feature &feature, emtf::Prediction &prediction) const
tensorflow::GraphDef * graphDefDxy_
std::array< float, NUM_PREDICTIONS > Prediction
Definition: Common.h:72
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:281
bool closeSession(Session *&session)
Definition: TensorFlow.cc:243
d
Definition: ztail.py:151
Session * createSession()
Definition: TensorFlow.cc:146
#define emtf_assert(expr)
Definition: DebugTools.h:18
constexpr int NUM_FEATURES
Definition: Common.h:68
char data[epos_bytes_allocation]
Definition: EPOS_Wrapper.h:80
virtual void preprocessing_dxy(const EMTFTrack &track, emtf::Feature &feature) const
std::array< float, NUM_FEATURES > Feature
Definition: Common.h:71