CMS 3D CMS Logo

 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
PtAssignmentEngineDxy.cc
Go to the documentation of this file.
2 
3 #include <cassert>
4 #include <iostream>
5 #include <sstream>
6 
7 #include "helper.h" // assert_no_abort
8 
9 PtAssignmentEngineDxy::PtAssignmentEngineDxy() : graphDefDxy_(nullptr), sessionDxy_(nullptr) {}
10 
// Presumably the destructor body: its signature (source line 11) is missing from this rendering.
// It deletes the owned graphDefDxy_ and, before that, handles sessionDxy_ when it is set.
12  if (sessionDxy_ != nullptr) {
// NOTE(review): source line 13 is missing here, so this branch appears empty. The symbol list
// at the end of this page references `bool closeSession(Session *&session)`; presumably that
// call releases sessionDxy_ here -- confirm against the original file.
14  }
15  delete graphDefDxy_;  // deleting a null pointer is a no-op, so no guard is needed
16 }
17 
// Stores the configuration of the Dxy NN and makes sure a TensorFlow graph definition and
// session exist; both graphDefDxy_ and sessionDxy_ are asserted to be non-null on exit.
//   verbose       : verbosity level, copied into verbose_.
//   pbFileNameDxy : name of the TensorFlow graph file (.pb); kept in pbFileNameDxy_ and
//                   prefixed with the directory "L1Trigger/L1TMuon/data/emtf_luts/".
18 void PtAssignmentEngineDxy::configure(int verbose, const std::string pbFileNameDxy) {
19  verbose_ = verbose;
20 
21  pbFileNameDxy_ = pbFileNameDxy;
// Despite the member-style trailing underscore, this is a local variable.
22  std::string pbFilePathDxy_ = "L1Trigger/L1TMuon/data/emtf_luts/" + pbFileNameDxy_;
23 
// Presumably the names of the graph's input node and output node(s) used when the network is
// evaluated -- confirm against the inference code.
24  inputNameDxy_ = "batch_normalization_1_input";
25  outputNamesDxy_ = {"dense_4/BiasAdd"};
26 
// The graph and session are only set up while their pointers are still null. A second call
// to configure() therefore keeps the existing ones, even if pbFileNameDxy is different.
27  if (graphDefDxy_ == nullptr) {
// NOTE(review): source line 28 is missing from this rendering, so this branch appears empty.
// The symbol list at the end of this page references `loadGraphDef(const std::string &pbFile)`;
// presumably it is called here with a path derived from pbFilePathDxy_ -- confirm.
29  }
30  emtf_assert(graphDefDxy_ != nullptr);
31 
32  if (sessionDxy_ == nullptr) {
// NOTE(review): source line 33 is missing from this rendering; presumably the session is
// created here (the symbol list references `createSession`) -- confirm.
34  }
35 
36  emtf_assert(sessionDxy_ != nullptr);
37 }
38 
// NOTE(review): the signature and the declaration of `instance` (source lines 39-40) are missing
// from this rendering. The symbol list at the end of this page declares
// `const PtAssignmentEngineAux2017 & aux() const`, so this is presumably the body of aux(),
// returning a shared PtAssignmentEngineAux2017 object (likely a function-local static) -- confirm.
41  return instance;
42 }
43 
45  emtf::Feature& feature,
46  emtf::Prediction& prediction) const {
// NOTE(review): the first signature line (source line 44) is missing from this rendering. The
// symbol list declares this function as
// `calculate_pt_dxy(const EMTFTrack &track, emtf::Feature &feature, emtf::Prediction &prediction) const`.
// Runs the Dxy NN on a single track: preprocessing_dxy() fills `feature` with the NN inputs, then
// call_tensorflow_dxy() evaluates the network and writes the outputs into `prediction`.
47  // This is called for each track instead of for entire track collection as was done in Phase-2 implementation
48  preprocessing_dxy(track, feature);
49  call_tensorflow_dxy(feature, prediction);
50  return;
51 }
52 
// NOTE(review): the signature line (source line 53) is missing from this rendering. The symbol
// list declares `preprocessing_dxy(const EMTFTrack &track, emtf::Feature &feature) const`.
// Builds the 23 NN input values from the track's PtLUT data, Mode() and Theta_fp(),
// and assigns them to `feature`.
54  // Mimic Phase-1 EMTF input calculations
55  // 6 delta Phis: S1-S2, S1-S3, S1-S4, S2-S3, S2-S4, S3-S4
56  // 6 delta Thetas: S1-S2, S1-S3, S1-S4, S2-S3, S2-S4, S3-S4
57  // 4 bends : set to zero if no CSC hit and thus RPC hit is used
58  // 1 FR bit: for ME1 only
59  // 1 Ring bit: for ME1 only
60  // 1 track Theta taken from stub coordinate in ME2, ME3, ME4 (in this priority)
61  // 4 RPC bits indicating if ME or RE hit was used in each station (S1, S2, S3, S4)
62  // Total: 23 variables
// NOTE(review): the order actually written into `feature` at the end of this function is:
// 6 dPhi, 6 dTheta, 4 bends, FR bit, track theta, ME1 ring bit, 4 RPC bits.
// The track theta comes before the ring bit, unlike the listing above.
63  std::array<float, 6> x_dphi;
64  std::array<float, 6> x_dtheta;
65  std::array<float, 4> x_bend_emtf;
66  std::array<float, 1> x_fr_emtf;
67  std::array<float, 1> x_trk_theta;
68  std::array<float, 1> x_me11ring;
69  std::array<float, 4> x_rpcbit;
70 
71  // Initialize to zeros
72  x_dphi.fill(0);
73  x_dtheta.fill(0);
74  //
75  x_bend_emtf.fill(0);
76  x_fr_emtf.fill(0);
77  x_trk_theta.fill(0);
78  x_me11ring.fill(0);
79  x_rpcbit.fill(0);
80 
81  EMTFPtLUT data = track.PtLUT();
82 
// Sentinel magnitudes marking an invalid delta theta / delta phi in the PtLUT; such entries
// are fed to the NN as 0 below.
83  const int invalid_dtheta = 127;
84  const int invalid_dphi = 8191;
85 
86  // Variables to extract from the PtLUT
87  int dPhi_12, dPhi_13, dPhi_14, dPhi_23, dPhi_24, dPhi_34;
88  int dTh_12, dTh_13, dTh_14, dTh_23, dTh_24, dTh_34;
89  int fr_1;
90  int bend_1, bend_2, bend_3, bend_4;
91  int rpc_1, rpc_2, rpc_3, rpc_4;
92  int St1_ring2 = data.st1_ring2;
93 
// -99 marks "no pattern": it stays in place for stations without a hit and is still passed
// to calcBendFromPattern below.
94  int pat1 = -99, pat2 = -99, pat3 = -99, pat4 = -99;
95 
96  // Which stations have hits
// Mode() is decoded as a 4-bit station mask: 8 = S1, 4 = S2, 2 = S3, 1 = S4
// (assumes Mode() is in [0, 15] -- TODO confirm).
97  int st1 = (track.Mode() >= 8);
98  int st2 = ((track.Mode() % 8) >= 4);
99  int st3 = ((track.Mode() % 4) >= 2);
100  int st4 = ((track.Mode() % 2) == 1);
101 
102  // Get valid pattern values
103  if (st1)
104  pat1 = data.cpattern[0];
105  if (st2)
106  pat2 = data.cpattern[1];
107  if (st3)
108  pat3 = data.cpattern[2];
109  if (st4)
110  pat4 = data.cpattern[3];
111 
112  // F/R bit
// NOTE(review): read from the PtLUT regardless of whether station 1 has a hit.
113  fr_1 = data.fr[0];
114 
115  // RPC hit in station
// A pattern value of 0 in a station with a hit is taken to mean an RPC hit was used there.
116  rpc_1 = (st1 ? (pat1 == 0) : 0);
117  rpc_2 = (st2 ? (pat2 == 0) : 0);
118  rpc_3 = (st3 ? (pat3 == 0) : 0);
119  rpc_4 = (st4 ? (pat4 == 0) : 0);
120 
121  // Calculate bends from patterns
122  bend_1 = aux().calcBendFromPattern(pat1, track.Endcap());
123  bend_2 = aux().calcBendFromPattern(pat2, track.Endcap());
124  bend_3 = aux().calcBendFromPattern(pat3, track.Endcap());
125  bend_4 = aux().calcBendFromPattern(pat4, track.Endcap());
126 
127  // Invalid bend value is 0 in the NN
// (-99 is treated here as calcBendFromPattern's "invalid" result.)
128  if (bend_1 == -99)
129  bend_1 = 0;
130  if (bend_2 == -99)
131  bend_2 = 0;
132  if (bend_3 == -99)
133  bend_3 = 0;
134  if (bend_4 == -99)
135  bend_4 = 0;
136 
137  // In the emulator RPCs get assigned abs(bend) = 5. This needs to be 0 for the NN.
138  if (std::abs(bend_1) == 5 && rpc_1 == 1)
139  bend_1 = 0;
140  if (std::abs(bend_2) == 5 && rpc_2 == 1)
141  bend_2 = 0;
142  if (std::abs(bend_3) == 5 && rpc_3 == 1)
143  bend_3 = 0;
144  if (std::abs(bend_4) == 5 && rpc_4 == 1)
145  bend_4 = 0;
146 
147  // Calculate delta phi
// Signed value = magnitude * (+1 if sign bit set, else -1); invalid magnitudes become 0.
148  dPhi_12 = (data.delta_ph[0] != invalid_dphi) ? data.delta_ph[0] * (data.sign_ph[0] ? 1 : -1) : 0;
149  dPhi_13 = (data.delta_ph[1] != invalid_dphi) ? data.delta_ph[1] * (data.sign_ph[1] ? 1 : -1) : 0;
150  dPhi_14 = (data.delta_ph[2] != invalid_dphi) ? data.delta_ph[2] * (data.sign_ph[2] ? 1 : -1) : 0;
151  dPhi_23 = (data.delta_ph[3] != invalid_dphi) ? data.delta_ph[3] * (data.sign_ph[3] ? 1 : -1) : 0;
152  dPhi_24 = (data.delta_ph[4] != invalid_dphi) ? data.delta_ph[4] * (data.sign_ph[4] ? 1 : -1) : 0;
153  dPhi_34 = (data.delta_ph[5] != invalid_dphi) ? data.delta_ph[5] * (data.sign_ph[5] ? 1 : -1) : 0;
154 
155  // Calculate delta theta
// Same sign/invalid handling as for delta phi.
156  dTh_12 = (data.delta_th[0] != invalid_dtheta) ? data.delta_th[0] * (data.sign_th[0] ? 1 : -1) : 0;
157  dTh_13 = (data.delta_th[1] != invalid_dtheta) ? data.delta_th[1] * (data.sign_th[1] ? 1 : -1) : 0;
158  dTh_14 = (data.delta_th[2] != invalid_dtheta) ? data.delta_th[2] * (data.sign_th[2] ? 1 : -1) : 0;
159  dTh_23 = (data.delta_th[3] != invalid_dtheta) ? data.delta_th[3] * (data.sign_th[3] ? 1 : -1) : 0;
160  dTh_24 = (data.delta_th[4] != invalid_dtheta) ? data.delta_th[4] * (data.sign_th[4] ? 1 : -1) : 0;
161  dTh_34 = (data.delta_th[5] != invalid_dtheta) ? data.delta_th[5] * (data.sign_th[5] ? 1 : -1) : 0;
162 
163  // Set dPhi and dTheta values to 0 if there was no hit in the station
164  if (!st1) {
165  dPhi_12 = 0;
166  dPhi_13 = 0;
167  dPhi_14 = 0;
168 
169  dTh_12 = 0;
170  dTh_13 = 0;
171  dTh_14 = 0;
172  }
173  if (!st2) {
174  dPhi_12 = 0;
175  dPhi_23 = 0;
176  dPhi_24 = 0;
177 
178  dTh_12 = 0;
179  dTh_23 = 0;
180  dTh_24 = 0;
181  }
182  if (!st3) {
183  dPhi_13 = 0;
184  dPhi_23 = 0;
185  dPhi_34 = 0;
186 
187  dTh_13 = 0;
188  dTh_23 = 0;
189  dTh_34 = 0;
190  }
191  if (!st4) {
192  dPhi_14 = 0;
193  dPhi_24 = 0;
194  dPhi_34 = 0;
195 
196  dTh_14 = 0;
197  dTh_24 = 0;
198  dTh_34 = 0;
199  }
200 
201  // Set NN inputs
202 
203  // NN was trained with the wrong sign convention. TO BE CHANGED LATER!
// NOTE(review): no sign change is applied here or in the two similar blocks below; the
// values are copied unchanged.
204  x_dphi[0] = dPhi_12;
205  x_dphi[1] = dPhi_13;
206  x_dphi[2] = dPhi_14;
207  x_dphi[3] = dPhi_23;
208  x_dphi[4] = dPhi_24;
209  x_dphi[5] = dPhi_34;
210 
211  // NN was trained with the wrong sign convention. TO BE CHANGED LATER!
212  x_dtheta[0] = dTh_12;
213  x_dtheta[1] = dTh_13;
214  x_dtheta[2] = dTh_14;
215  x_dtheta[3] = dTh_23;
216  x_dtheta[4] = dTh_24;
217  x_dtheta[5] = dTh_34;
218 
219  // NN was trained with the wrong sign convention. TO BE CHANGED LATER!
220  x_bend_emtf[0] = bend_1;
221  x_bend_emtf[1] = bend_2;
222  x_bend_emtf[2] = bend_3;
223  x_bend_emtf[3] = bend_4;
224 
225  x_fr_emtf[0] = fr_1;
// NOTE(review): this is simply track.Theta_fp(). Any ME2/ME3/ME4 priority mentioned in the
// header comment is not applied in this function.
226  x_trk_theta[0] = track.Theta_fp();
// NOTE(review): taken from the PtLUT regardless of whether station 1 has a hit.
227  x_me11ring[0] = St1_ring2;
228 
229  x_rpcbit[0] = rpc_1;
230  x_rpcbit[1] = rpc_2;
231  x_rpcbit[2] = rpc_3;
232  x_rpcbit[3] = rpc_4;
233 
// 6 + 6 + 4 + 1 + 1 + 1 + 4 = 23 values, in the order listed in the NOTE near the top.
234  feature = {{x_dphi[0], x_dphi[1], x_dphi[2], x_dphi[3], x_dphi[4], x_dphi[5],
235  x_dtheta[0], x_dtheta[1], x_dtheta[2], x_dtheta[3], x_dtheta[4], x_dtheta[5],
236  x_bend_emtf[0], x_bend_emtf[1], x_bend_emtf[2], x_bend_emtf[3], x_fr_emtf[0], x_trk_theta[0],
237  x_me11ring[0], x_rpcbit[0], x_rpcbit[1], x_rpcbit[2], x_rpcbit[3]}};
238  return;
239 }
240 
// NOTE(review): the signature line (source line 241) is missing from this rendering. The symbol
// list declares `call_tensorflow_dxy(const emtf::Feature &feature, emtf::Prediction &prediction) const`.
// Copies `feature` into a 1 x NUM_FEATURES float tensor, evaluates the network, and writes the
// first two columns of the single output, divided by their training scale factors, into `prediction`.
242  tensorflow::Tensor input(tensorflow::DT_FLOAT, {1, emtf::NUM_FEATURES});
243  std::vector<tensorflow::Tensor> outputs;
// NOTE(review): emtf::Feature is std::array<float, NUM_FEATURES> (see the symbol list), so this
// condition always holds; a static_assert would express the same check at compile time.
244  emtf_assert(feature.size() == emtf::NUM_FEATURES);
245 
246  float* d = input.flat<float>().data();
247  std::copy(feature.begin(), feature.end(), d);
// NOTE(review): source line 248 is missing from this rendering. Presumably it is the
// tensorflow::run(...) call (referenced in the symbol list) that fills `outputs`; without it
// `outputs` stays empty and the next assert fails -- confirm.
249  emtf_assert(outputs.size() == 1);
// Also always true: emtf::Prediction is std::array<float, NUM_PREDICTIONS>.
250  emtf_assert(prediction.size() == emtf::NUM_PREDICTIONS);
251 
252  const float reg_pt_scale = 100.0; // a scale factor applied to regression during training
// With a value of 1.0, the division by this factor below does not change prediction[1].
253  const float reg_dxy_scale = 1.0; // a scale factor applied to regression during training
254 
// Column 0 of the output row goes to prediction[0] (pt) and column 1 to prediction[1] (dxy);
// this requires NUM_PREDICTIONS >= 2.
255  prediction.at(0) = outputs[0].matrix<float>()(0, 0);
256  prediction.at(1) = outputs[0].matrix<float>()(0, 1);
257 
258  // Remove scale factor used during training
259  prediction.at(0) /= reg_pt_scale;
260  prediction.at(1) /= reg_dxy_scale;
261  return;
262 }
Session * createSession(SessionOptions &sessionOptions)
Definition: TensorFlow.cc:85
int Theta_fp() const
Definition: EMTFTrack.h:187
uint16_t sign_ph[6]
Definition: EMTFTrack.h:32
int Endcap() const
Definition: EMTFTrack.h:165
static PFTauRenderPlugin instance
bool verbose
GraphDef * loadGraphDef(const std::string &pbFile)
Definition: TensorFlow.cc:68
std::vector< std::string > outputNamesDxy_
constexpr int NUM_PREDICTIONS
Definition: Common.h:69
tensorflow::Session * sessionDxy_
uint16_t delta_ph[6]
Definition: EMTFTrack.h:30
void configure(int verbose, const std::string pbFileNameDxy)
static std::string const input
Definition: EdmProvDump.cc:47
tensorflow::GraphDef * graphDefDxy_
std::array< float, NUM_PREDICTIONS > Prediction
Definition: Common.h:72
tuple d
Definition: ztail.py:151
uint16_t delta_th[6]
Definition: EMTFTrack.h:31
int calcBendFromPattern(const int pattern, const int endcap) const
void run(Session *session, const NamedTensorList &inputs, const std::vector< std::string > &outputNames, std::vector< Tensor > *outputs, const thread::ThreadPoolOptions &threadPoolOptions)
Definition: TensorFlow.cc:213
bool closeSession(Session *&session)
Definition: TensorFlow.cc:198
Abs< T >::type abs(const T &t)
Definition: Abs.h:22
uint16_t cpattern[4]
Definition: EMTFTrack.h:34
uint16_t fr[4]
Definition: EMTFTrack.h:37
uint16_t sign_th[6]
Definition: EMTFTrack.h:33
const PtAssignmentEngineAux2017 & aux() const
uint16_t st1_ring2
Definition: EMTFTrack.h:28
virtual void preprocessing_dxy(const EMTFTrack &track, emtf::Feature &feature) const
#define emtf_assert(expr)
Definition: DebugTools.h:18
constexpr int NUM_FEATURES
Definition: Common.h:68
char data[epos_bytes_allocation]
Definition: EPOS_Wrapper.h:79
int Mode() const
Definition: EMTFTrack.h:168
std::array< float, NUM_FEATURES > Feature
Definition: Common.h:71
virtual void call_tensorflow_dxy(const emtf::Feature &feature, emtf::Prediction &prediction) const
EMTFPtLUT PtLUT() const
Definition: EMTFTrack.h:129
virtual void calculate_pt_dxy(const EMTFTrack &track, emtf::Feature &feature, emtf::Prediction &prediction) const