CMS 3D CMS Logo

NeuralNetwork.h
Go to the documentation of this file.
1 #ifndef RecoTracker_LSTCore_src_alpaka_NeuralNetwork_h
2 #define RecoTracker_LSTCore_src_alpaka_NeuralNetwork_h
3 
10 
11 #include "NeuralNetworkWeights.h"
12 
14 
15  namespace t5dnn {
16 
17  template <typename TAcc>
18  ALPAKA_FN_ACC ALPAKA_FN_INLINE float runInference(TAcc const& acc,
21  SegmentsConst segments,
22  TripletsConst triplets,
23  const float* xVec,
24  const float* yVec,
25  const unsigned int* mdIndices,
26  const uint16_t* lowerModuleIndices,
27  unsigned int innerTripletIndex,
28  unsigned int outerTripletIndex,
29  float innerRadius,
30  float outerRadius,
31  float bridgeRadius) {
32  // Unpack x-coordinates of hits
33  float x1 = xVec[0];
34  float x2 = xVec[1];
35  float x3 = xVec[2];
36  float x4 = xVec[3];
37  float x5 = xVec[4];
38  // Unpack y-coordinates of hits
39  float y1 = yVec[0];
40  float y2 = yVec[1];
41  float y3 = yVec[2];
42  float y4 = yVec[3];
43  float y5 = yVec[4];
44  // Unpack module indices
45  unsigned int mdIndex1 = mdIndices[0];
46  unsigned int mdIndex2 = mdIndices[1];
47  unsigned int mdIndex3 = mdIndices[2];
48  unsigned int mdIndex4 = mdIndices[3];
49  unsigned int mdIndex5 = mdIndices[4];
50  // Unpack module indices
51  uint16_t lowerModuleIndex1 = lowerModuleIndices[0];
52  uint16_t lowerModuleIndex2 = lowerModuleIndices[1];
53  uint16_t lowerModuleIndex3 = lowerModuleIndices[2];
54  uint16_t lowerModuleIndex4 = lowerModuleIndices[3];
55  uint16_t lowerModuleIndex5 = lowerModuleIndices[4];
56  // Compute some convenience variables
57  short layer2_adjustment = 0;
58  if (modules.layers()[lowerModuleIndex1] == 1) {
59  layer2_adjustment = 1; // get upper segment to be in second layer
60  }
61  unsigned int md_idx_for_t5_eta_phi =
62  segments.mdIndices()[triplets.segmentIndices()[innerTripletIndex][0]][layer2_adjustment];
63  bool is_endcap1 = (modules.subdets()[lowerModuleIndex1] == 4); // true if anchor hit 1 is in the endcap
64  bool is_endcap2 = (modules.subdets()[lowerModuleIndex2] == 4); // true if anchor hit 2 is in the endcap
65  bool is_endcap3 = (modules.subdets()[lowerModuleIndex3] == 4); // true if anchor hit 3 is in the endcap
66  bool is_endcap4 = (modules.subdets()[lowerModuleIndex4] == 4); // true if anchor hit 4 is in the endcap
67  bool is_endcap5 = (modules.subdets()[lowerModuleIndex5] == 4); // true if anchor hit 5 is in the endcap
68 
69  // Build DNN input vector (corresponding output N-tuple branch noted in parenthetical in comment)
70  float x[38] = {
71  alpaka::math::log10(acc, 2 * k2Rinv1GeVf * innerRadius), // inner T3 pT (t3_pt)
72  mds.anchorEta()[mdIndex1], // inner T3 anchor hit 1 eta (t3_0_eta)
73  mds.anchorPhi()[mdIndex1], // inner T3 anchor hit 1 phi (t3_0_phi)
74  mds.anchorZ()[mdIndex1], // inner T3 anchor hit 1 z (t3_0_z)
75  alpaka::math::sqrt(acc, x1 * x1 + y1 * y1), // inner T3 anchor hit 1 r (t3_0_r)
76  float(modules.layers()[lowerModuleIndex1] + 6 * is_endcap1), // inner T3 anchor hit 1 layer (t3_0_layer)
77  mds.anchorEta()[mdIndex2], // inner T3 anchor hit 2 eta (t3_2_eta)
78  mds.anchorPhi()[mdIndex2], // inner T3 anchor hit 2 phi (t3_2_phi)
79  mds.anchorZ()[mdIndex2], // inner T3 anchor hit 2 z (t3_2_z)
80  alpaka::math::sqrt(acc, x2 * x2 + y2 * y2), // inner T3 anchor hit 2 r (t3_2_r)
81  float(modules.layers()[lowerModuleIndex2] + 6 * is_endcap2), // inner T3 anchor hit 2 layer (t3_2_layer)
82  mds.anchorEta()[mdIndex3], // inner T3 anchor hit 3 eta (t3_4_eta)
83  mds.anchorPhi()[mdIndex3], // inner T3 anchor hit 3 phi (t3_4_phi)
84  mds.anchorZ()[mdIndex3], // inner T3 anchor hit 3 z (t3_4_z)
85  alpaka::math::sqrt(acc, x3 * x3 + y3 * y3), // inner T3 anchor hit 3 r (t3_4_r)
86  float(modules.layers()[lowerModuleIndex3] + 6 * is_endcap3), // inner T3 anchor hit 3 layer (t3_4_layer)
87  alpaka::math::log10(acc, 2 * k2Rinv1GeVf * outerRadius), // outer T3 pT (t3_pt)
88  mds.anchorEta()[mdIndex3], // outer T3 anchor hit 4 eta (t3_0_eta)
89  mds.anchorPhi()[mdIndex3], // outer T3 anchor hit 4 phi (t3_0_phi)
90  mds.anchorZ()[mdIndex3], // outer T3 anchor hit 3 eta (t3_0_z)
91  alpaka::math::sqrt(acc, x3 * x3 + y3 * y3), // outer T3 anchor hit 3 r (t3_0_r)
92  float(modules.layers()[lowerModuleIndex3] + 6 * is_endcap3), // outer T3 anchor hit 3 layer (t3_0_layer)
93  mds.anchorEta()[mdIndex4], // outer T3 anchor hit 4 eta (t3_2_eta)
94  mds.anchorPhi()[mdIndex4], // outer T3 anchor hit 4 phi (t3_2_phi)
95  mds.anchorZ()[mdIndex4], // outer T3 anchor hit 4 z (t3_2_z)
96  alpaka::math::sqrt(acc, x4 * x4 + y4 * y4), // outer T3 anchor hit 4 r (t3_2_r)
97  float(modules.layers()[lowerModuleIndex4] + 6 * is_endcap4), // outer T3 anchor hit 4 layer (t3_2_layer)
98  mds.anchorEta()[mdIndex5], // outer T3 anchor hit 5 eta (t3_4_eta)
99  mds.anchorPhi()[mdIndex5], // outer T3 anchor hit 5 phi (t3_4_phi)
100  mds.anchorZ()[mdIndex5], // outer T3 anchor hit 5 z (t3_4_z)
101  alpaka::math::sqrt(acc, x5 * x5 + y5 * y5), // outer T3 anchor hit 5 r (t3_4_r)
102  float(modules.layers()[lowerModuleIndex5] + 6 * is_endcap5), // outer T3 anchor hit 5 layer (t3_4_layer)
103  alpaka::math::log10(acc, (innerRadius + outerRadius) * k2Rinv1GeVf), // T5 pT (t5_pt)
104  mds.anchorEta()[md_idx_for_t5_eta_phi], // T5 eta (t5_eta)
105  mds.anchorPhi()[md_idx_for_t5_eta_phi], // T5 phi (t5_phi)
106  alpaka::math::log10(acc, innerRadius), // T5 inner radius (t5_innerRadius)
107  alpaka::math::log10(acc, bridgeRadius), // T5 bridge radius (t5_bridgeRadius)
108  alpaka::math::log10(acc, outerRadius) // T5 outer radius (t5_outerRadius)
109  };
110 
111  // (0): Linear(in_features=38, out_features=32, bias=True) => x = x*W_T + b
112  float x_0[32];
113  for (unsigned int col = 0; col < 32; ++col) {
114  x_0[col] = 0;
115  for (unsigned int inner = 0; inner < 38; ++inner) {
116  x_0[col] += x[inner] * wgtT_0[inner][col];
117  }
118  x_0[col] += bias_0[col];
119  }
120 
121  // (1): ReLU()
122  float x_1[32];
123  for (unsigned int col = 0; col < 32; ++col) {
124  x_1[col] = (x_0[col] > 0.f) ? x_0[col] : 0.f;
125  }
126 
127  // (2): Linear(in_features=32, out_features=32, bias=True) => x = x*W_T + b
128  float x_2[32];
129  for (unsigned int col = 0; col < 32; ++col) {
130  x_2[col] = 0;
131  for (unsigned int inner = 0; inner < 32; ++inner) {
132  x_2[col] += x_1[inner] * wgtT_2[inner][col];
133  }
134  x_2[col] += bias_2[col];
135  }
136 
137  // (3): ReLU()
138  float x_3[32];
139  for (unsigned int col = 0; col < 32; ++col) {
140  x_3[col] = (x_2[col] > 0.f) ? x_2[col] : 0.f;
141  }
142 
143  // (4): Linear(in_features=32, out_features=1, bias=True) => x = x*W_T + b
144  float x_4[1];
145  for (unsigned int col = 0; col < 1; ++col) {
146  x_4[col] = 0;
147  for (unsigned int inner = 0; inner < 32; ++inner) {
148  x_4[col] += x_3[inner] * wgtT_4[inner][col];
149  }
150  x_4[col] += bias_4[col];
151  }
152 
153  // (5): Sigmoid()
154  float x_5[1];
155  for (unsigned int col = 0; col < 1; ++col) {
156  x_5[col] = alpaka::math::exp(acc, x_4[col]) / (alpaka::math::exp(acc, x_4[col]) + 1);
157  }
158 
159  return x_5[0];
160  }
161 
162  } // namespace t5dnn
163 } // namespace ALPAKA_ACCELERATOR_NAMESPACE::lst
164 
165 #endif
ALPAKA_FN_ACC ALPAKA_FN_INLINE float runInference(TAcc const &acc, ModulesConst modules, MiniDoubletsConst mds, SegmentsConst segments, TripletsConst triplets, const float *xVec, const float *yVec, const unsigned int *mdIndices, const uint16_t *lowerModuleIndices, unsigned int innerTripletIndex, unsigned int outerTripletIndex, float innerRadius, float outerRadius, float bridgeRadius)
Definition: NeuralNetwork.h:18
ALPAKA_STATIC_ACC_MEM_GLOBAL constexpr float k2Rinv1GeVf
Definition: Common.h:49
TripletsSoA::ConstView TripletsConst
Definition: TripletsSoA.h:31
ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_0[38][32]
T sqrt(T t)
Definition: SSEVec.h:23
double f[11][100]
ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_4[1]
MiniDoubletsSoA::ConstView MiniDoubletsConst
ModulesSoA::ConstView ModulesConst
Definition: ModulesSoA.h:47
ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_4[32][1]
col
Definition: cuy.py:1009
SegmentsSoA::ConstView SegmentsConst
Definition: SegmentsSoA.h:49
float x
ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_2[32]
ALPAKA_STATIC_ACC_MEM_GLOBAL const float wgtT_2[32][32]
ALPAKA_STATIC_ACC_MEM_GLOBAL const float bias_0[32]