CMS 3D CMS Logo

choleskyInversion.h
Go to the documentation of this file.
1 #ifndef DataFormat_Math_choleskyInversion_h
2 #define DataFormat_Math_choleskyInversion_h
3 
4 #include <cmath>
5 
6 #include <Eigen/Core>
7 
8 namespace math {
9  namespace cholesky {
10 
/**
 * Invert a symmetric positive-definite N x N matrix of any compile-time size.
 *
 * Port of the CERNLIB symmetric-matrix inversion ("origin: CERNLIB"):
 *  1. the lower triangle of src is copied into a local N x N work array;
 *  2. the work array is factorized in place without square roots, and each
 *     diagonal entry is replaced by the reciprocal of its pivot;
 *  3. the triangular factor is inverted in place;
 *  4. the inverse is assembled and written to dst.
 *
 * @tparam N  matrix dimension; defaults to M2::ColsAtCompileTime and must be a
 *            positive compile-time constant.
 * @param src input matrix. Only its lower triangle (src(i, j) with i >= j) is
 *            read, the same convention as invert22 ... invert66. It is expected
 *            to be symmetric positive definite: there is no pivoting and no
 *            pivot check, so a zero pivot divides by zero.
 * @param dst output. Every element dst(i, j), 0 <= i, j < N, is written. src is
 *            fully copied into the work array before dst is touched, so src and
 *            dst may be the same object.
 */
template <typename M1, typename M2, int N = M2::ColsAtCompileTime>
// without this: either does not compile or compiles and then fails silently at runtime
#ifdef __CUDACC__
__host__ __device__
#endif
    inline constexpr void
    invertNN(M1 const& src, M2& dst) {
  static_assert(N > 0, "invertNN needs a positive compile-time matrix dimension");

  // origin: CERNLIB

  using T = typename M2::Scalar;

  // Work array: a[i][i] holds the diagonal and a[j][i] (j > i) the lower triangle.
  // It is value-initialized because C++17 does not allow uninitialized local
  // variables in a constexpr function.
  T a[N][N] = {};
  for (int i = 0; i < N; ++i) {
    a[i][i] = src(i, i);
    for (int j = i + 1; j < N; ++j)
      a[j][i] = src(j, i);
  }

  // In-place factorization; a[j][j] becomes the reciprocal of the j-th pivot.
  for (int j = 0; j < N; ++j) {
    a[j][j] = T(1.) / a[j][j];
    int jp1 = j + 1;
    for (int l = jp1; l < N; ++l) {
      a[j][l] = a[j][j] * a[l][j];
      T s1 = -a[l][jp1];
      for (int i = 0; i < jp1; ++i)
        s1 += a[l][i] * a[i][jp1];
      a[l][jp1] = -s1;
    }
  }

  if constexpr (N == 1) {
    dst(0, 0) = a[0][0];
  } else {
    // Kept in the else-branch so the N == 1 instantiation never compiles the
    // out-of-bounds accesses a[0][1] / a[1][0].

    // Invert the triangular factor in place.
    a[0][1] = -a[0][1];
    a[1][0] = a[0][1] * a[1][1];
    for (int j = 2; j < N; ++j) {
      int jm1 = j - 1;
      for (int k = 0; k < jm1; ++k) {
        T s31 = a[k][j];
        for (int i = k; i < jm1; ++i)
          s31 += a[k][i + 1] * a[i + 1][j];
        a[k][j] = -s31;
        a[j][k] = -s31 * a[j][j];
      }
      a[jm1][j] = -a[jm1][j];
      a[j][jm1] = a[jm1][j] * a[j][j];
    }

    // Assemble the inverse; each off-diagonal value goes to both triangles of dst.
    int j = 0;
    while (j < N - 1) {
      T s33 = a[j][j];
      for (int i = j + 1; i < N; ++i)
        s33 += a[j][i] * a[i][j];
      dst(j, j) = s33;

      ++j;
      for (int k = 0; k < j; ++k) {
        T s32 = 0;
        for (int i = j; i < N; ++i)
          s32 += a[k][i] * a[i][j];
        dst(k, j) = dst(j, k) = s32;
      }
    }
    dst(j, j) = a[j][j];
  }
}
78 
/**
 * Invert a 1x1 matrix: dst(0, 0) = 1 / src(0, 0).
 *
 * The single element is read before dst is written, so src and dst may alias.
 * A zero element is not checked for.
 */
template <typename M1, typename M2>
inline constexpr void __attribute__((always_inline)) invert11(M1 const& src, M2& dst) {
  using F = decltype(src(0, 0));
  auto const pivot = src(0, 0);
  dst(0, 0) = F(1.0) / pivot;
}
96 
/**
 * Closed-form inverse of a 2x2 symmetric positive-definite matrix.
 *
 * Reads src(0, 0), src(1, 0) and src(1, 1); src(0, 1) is never read.
 * Writes dst(0, 0), dst(1, 0) and dst(1, 1). The upper element dst(0, 1) is
 * left untouched; the caller mirrors it (Inverter<..., 2> calls symmetrize22).
 * Pivots are not checked, so a singular input divides by zero.
 */
template <typename M1, typename M2>
inline constexpr void __attribute__((always_inline)) invert22(M1 const& src, M2& dst) {
  using F = decltype(src(0, 0));

  // Reciprocal of the first pivot, 1 / a00.
  auto const invD0 = F(1.0) / src(0, 0);
  // a10^2 / a00: the amount the first pivot removes from a11.
  auto const correction = src(1, 0) * src(1, 0) * invD0;
  // Reciprocal of the second pivot, 1 / (a11 - a10^2 / a00).
  auto const invD1 = F(1.0) / (src(1, 1) - correction);

  auto const extra00 = correction * invD0 * invD1;

  dst(0, 0) = extra00 + invD0;
  dst(1, 0) = -src(1, 0) * invD0 * invD1;
  dst(1, 1) = invD1;
}
110 
// Closed-form inverse of a 3x3 symmetric positive-definite matrix, fully unrolled.
// Scheme: factorize src = L * D * L^T without square roots (L unit lower
// triangular, D = diag(d0, d1, d2)), then build dst = L^-T * D^-1 * L^-1.
// Reads only the lower triangle of src (src(i, j) with i >= j) and writes only
// the lower triangle of dst; the caller mirrors the upper triangle
// (Inverter<..., 3> calls symmetrize33). Every src read happens before the first
// dst write, so src and dst may alias. Pivots are not checked: a zero pivot
// divides by zero.
111  template <typename M1, typename M2>
112  inline constexpr void __attribute__((always_inline)) invert33(M1 const& src, M2& dst) {
113  using F = decltype(src(0, 0));
// luc0, luc2, luc5 end up holding 1/d0, 1/d1, 1/d2. The other lucK hold
// L(r, c) * d_c (0-based r, c): factor entries before division by their pivot.
114  auto luc0 = F(1.0) / src(0, 0);
115  auto luc1 = src(1, 0);
116  auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
117  luc2 = F(1.0) / luc2;
118  auto luc3 = src(2, 0);
119  auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
120  auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + (luc2 * luc4) * luc4);
121  luc5 = F(1.0) / luc5;
122 
// liRC is element (R-1, C-1) of L^-1 (its unit diagonal is implicit).
123  auto li21 = -luc0 * luc1;
124  auto li32 = -(luc2 * luc4);
125  auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;
126 
// dst(i, j) = sum_k (1/d_k) * Linv(k, i) * Linv(k, j), lower triangle only.
127  dst(0, 0) = luc5 * li31 * li31 + li21 * li21 * luc2 + luc0;
128  dst(1, 0) = luc5 * li31 * li32 + li21 * luc2;
129  dst(1, 1) = luc5 * li32 * li32 + luc2;
130  dst(2, 0) = luc5 * li31;
131  dst(2, 1) = luc5 * li32;
132  dst(2, 2) = luc5;
133  }
134 
// Closed-form inverse of a 4x4 symmetric positive-definite matrix, fully unrolled.
// Scheme: factorize src = L * D * L^T without square roots (L unit lower
// triangular, D = diag(d0..d3)), then build dst = L^-T * D^-1 * L^-1.
// Reads only the lower triangle of src (src(i, j) with i >= j) and writes only
// the lower triangle of dst; the caller mirrors the upper triangle
// (Inverter<..., 4> calls symmetrize44). Every src read happens before the first
// dst write, so src and dst may alias. Pivots are not checked: a zero pivot
// divides by zero.
135  template <typename M1, typename M2>
136  inline constexpr void __attribute__((always_inline)) invert44(M1 const& src, M2& dst) {
137  using F = decltype(src(0, 0));
// luc0, luc2, luc5, luc9 end up holding 1/d0 .. 1/d3. The other lucK hold
// L(r, c) * d_c (0-based r, c): factor entries before division by their pivot.
138  auto luc0 = F(1.0) / src(0, 0);
139  auto luc1 = src(1, 0);
140  auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
141  luc2 = F(1.0) / luc2;
142  auto luc3 = src(2, 0);
143  auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
144  auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + luc2 * luc4 * luc4);
145  luc5 = F(1.0) / luc5;
146  auto luc6 = src(3, 0);
147  auto luc7 = (src(3, 1) - luc0 * luc1 * luc6);
148  auto luc8 = (src(3, 2) - luc0 * luc3 * luc6 - luc2 * luc4 * luc7);
149  auto luc9 = src(3, 3) - (luc0 * luc6 * luc6 + luc2 * luc7 * luc7 + luc8 * (luc8 * luc5));
150  luc9 = F(1.0) / luc9;
151 
// liRC is element (R-1, C-1) of L^-1 (its unit diagonal is implicit).
152  auto li21 = -luc1 * luc0;
153  auto li32 = -luc2 * luc4;
154  auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;
155  auto li43 = -(luc8 * luc5);
156  auto li42 = (luc4 * luc8 * luc5 - luc7) * luc2;
157  auto li41 = (-luc1 * (luc2 * luc4) * (luc8 * luc5) + luc1 * (luc2 * luc7) + luc3 * (luc8 * luc5) - luc6) * luc0;
158 
// dst(i, j) = sum_k (1/d_k) * Linv(k, i) * Linv(k, j), lower triangle only.
159  dst(0, 0) = luc9 * li41 * li41 + luc5 * li31 * li31 + luc2 * li21 * li21 + luc0;
160  dst(1, 0) = luc9 * li41 * li42 + luc5 * li31 * li32 + luc2 * li21;
161  dst(1, 1) = luc9 * li42 * li42 + luc5 * li32 * li32 + luc2;
162  dst(2, 0) = luc9 * li41 * li43 + luc5 * li31;
163  dst(2, 1) = luc9 * li42 * li43 + luc5 * li32;
164  dst(2, 2) = luc9 * li43 * li43 + luc5;
165  dst(3, 0) = luc9 * li41;
166  dst(3, 1) = luc9 * li42;
167  dst(3, 2) = luc9 * li43;
168  dst(3, 3) = luc9;
169  }
170 
// Closed-form inverse of a 5x5 symmetric positive-definite matrix, fully unrolled.
// Scheme: factorize src = L * D * L^T without square roots (L unit lower
// triangular, D = diag(d0..d4)), then build dst = L^-T * D^-1 * L^-1.
// Reads only the lower triangle of src (src(i, j) with i >= j) and writes only
// the lower triangle of dst; the caller mirrors the upper triangle
// (Inverter<..., 5> calls symmetrize55). Every src read happens before the first
// dst write, so src and dst may alias. Pivots are not checked: a zero pivot
// divides by zero.
171  template <typename M1, typename M2>
172  inline constexpr void __attribute__((always_inline)) invert55(M1 const& src, M2& dst) {
173  using F = decltype(src(0, 0));
// luc0, luc2, luc5, luc9, luc14 end up holding 1/d0 .. 1/d4. The other lucK
// hold L(r, c) * d_c (0-based r, c): factor entries before division by their pivot.
174  auto luc0 = F(1.0) / src(0, 0);
175  auto luc1 = src(1, 0);
176  auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
177  luc2 = F(1.0) / luc2;
178  auto luc3 = src(2, 0);
179  auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
180  auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + luc2 * luc4 * luc4);
181  luc5 = F(1.0) / luc5;
182  auto luc6 = src(3, 0);
183  auto luc7 = (src(3, 1) - luc0 * luc1 * luc6);
184  auto luc8 = (src(3, 2) - luc0 * luc3 * luc6 - luc2 * luc4 * luc7);
185  auto luc9 = src(3, 3) - (luc0 * luc6 * luc6 + luc2 * luc7 * luc7 + luc8 * (luc8 * luc5));
186  luc9 = F(1.0) / luc9;
187  auto luc10 = src(4, 0);
188  auto luc11 = (src(4, 1) - luc0 * luc1 * luc10);
189  auto luc12 = (src(4, 2) - luc0 * luc3 * luc10 - luc2 * luc4 * luc11);
190  auto luc13 = (src(4, 3) - luc0 * luc6 * luc10 - luc2 * luc7 * luc11 - luc5 * luc8 * luc12);
191  auto luc14 =
192  src(4, 4) - (luc0 * luc10 * luc10 + luc2 * luc11 * luc11 + luc5 * luc12 * luc12 + luc9 * luc13 * luc13);
193  luc14 = F(1.0) / luc14;
194 
// liRC is element (R-1, C-1) of L^-1 (its unit diagonal is implicit).
195  auto li21 = -luc1 * luc0;
196  auto li32 = -luc2 * luc4;
197  auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;
198  auto li43 = -(luc8 * luc5);
199  auto li42 = (luc4 * luc8 * luc5 - luc7) * luc2;
200  auto li41 = (-luc1 * (luc2 * luc4) * (luc8 * luc5) + luc1 * (luc2 * luc7) + luc3 * (luc8 * luc5) - luc6) * luc0;
201  auto li54 = -luc13 * luc9;
202  auto li53 = (luc13 * luc8 * luc9 - luc12) * luc5;
203  auto li52 = (-luc4 * luc8 * luc13 * luc5 * luc9 + luc4 * luc12 * luc5 + luc7 * luc13 * luc9 - luc11) * luc2;
204  auto li51 = (luc1 * luc4 * luc8 * luc13 * luc2 * luc5 * luc9 - luc13 * luc8 * luc3 * luc9 * luc5 -
205  luc12 * luc4 * luc1 * luc2 * luc5 - luc13 * luc7 * luc1 * luc9 * luc2 + luc11 * luc1 * luc2 +
206  luc12 * luc3 * luc5 + luc13 * luc6 * luc9 - luc10) *
207  luc0;
208 
// dst(i, j) = sum_k (1/d_k) * Linv(k, i) * Linv(k, j), lower triangle only.
209  dst(0, 0) = luc14 * li51 * li51 + luc9 * li41 * li41 + luc5 * li31 * li31 + luc2 * li21 * li21 + luc0;
210  dst(1, 0) = luc14 * li51 * li52 + luc9 * li41 * li42 + luc5 * li31 * li32 + luc2 * li21;
211  dst(1, 1) = luc14 * li52 * li52 + luc9 * li42 * li42 + luc5 * li32 * li32 + luc2;
212  dst(2, 0) = luc14 * li51 * li53 + luc9 * li41 * li43 + luc5 * li31;
213  dst(2, 1) = luc14 * li52 * li53 + luc9 * li42 * li43 + luc5 * li32;
214  dst(2, 2) = luc14 * li53 * li53 + luc9 * li43 * li43 + luc5;
215  dst(3, 0) = luc14 * li51 * li54 + luc9 * li41;
216  dst(3, 1) = luc14 * li52 * li54 + luc9 * li42;
217  dst(3, 2) = luc14 * li53 * li54 + luc9 * li43;
218  dst(3, 3) = luc14 * li54 * li54 + luc9;
219  dst(4, 0) = luc14 * li51;
220  dst(4, 1) = luc14 * li52;
221  dst(4, 2) = luc14 * li53;
222  dst(4, 3) = luc14 * li54;
223  dst(4, 4) = luc14;
224  }
225 
// Closed-form inverse of a 6x6 symmetric positive-definite matrix, fully unrolled.
// Scheme: factorize src = L * D * L^T without square roots (L unit lower
// triangular, D = diag(d0..d5)), then build dst = L^-T * D^-1 * L^-1.
// Reads only the lower triangle of src (src(i, j) with i >= j) and writes only
// the lower triangle of dst; the caller mirrors the upper triangle
// (Inverter<..., 6> calls symmetrize66). Every src read happens before the first
// dst write, so src and dst may alias. Pivots are not checked: a zero pivot
// divides by zero.
226  template <typename M1, typename M2>
227  inline constexpr void __attribute__((always_inline)) invert66(M1 const& src, M2& dst) {
228  using F = decltype(src(0, 0));
// luc0, luc2, luc5, luc9, luc14, luc20 end up holding 1/d0 .. 1/d5. The other
// lucK hold L(r, c) * d_c (0-based r, c): factor entries before division by
// their pivot.
229  auto luc0 = F(1.0) / src(0, 0);
230  auto luc1 = src(1, 0);
231  auto luc2 = src(1, 1) - luc0 * luc1 * luc1;
232  luc2 = F(1.0) / luc2;
233  auto luc3 = src(2, 0);
234  auto luc4 = (src(2, 1) - luc0 * luc1 * luc3);
235  auto luc5 = src(2, 2) - (luc0 * luc3 * luc3 + luc2 * luc4 * luc4);
236  luc5 = F(1.0) / luc5;
237  auto luc6 = src(3, 0);
238  auto luc7 = (src(3, 1) - luc0 * luc1 * luc6);
239  auto luc8 = (src(3, 2) - luc0 * luc3 * luc6 - luc2 * luc4 * luc7);
240  auto luc9 = src(3, 3) - (luc0 * luc6 * luc6 + luc2 * luc7 * luc7 + luc8 * (luc8 * luc5));
241  luc9 = F(1.0) / luc9;
242  auto luc10 = src(4, 0);
243  auto luc11 = (src(4, 1) - luc0 * luc1 * luc10);
244  auto luc12 = (src(4, 2) - luc0 * luc3 * luc10 - luc2 * luc4 * luc11);
245  auto luc13 = (src(4, 3) - luc0 * luc6 * luc10 - luc2 * luc7 * luc11 - luc5 * luc8 * luc12);
246  auto luc14 =
247  src(4, 4) - (luc0 * luc10 * luc10 + luc2 * luc11 * luc11 + luc5 * luc12 * luc12 + luc9 * luc13 * luc13);
248  luc14 = F(1.0) / luc14;
249  auto luc15 = src(5, 0);
250  auto luc16 = (src(5, 1) - luc0 * luc1 * luc15);
251  auto luc17 = (src(5, 2) - luc0 * luc3 * luc15 - luc2 * luc4 * luc16);
252  auto luc18 = (src(5, 3) - luc0 * luc6 * luc15 - luc2 * luc7 * luc16 - luc5 * luc8 * luc17);
253  auto luc19 =
254  (src(5, 4) - luc0 * luc10 * luc15 - luc2 * luc11 * luc16 - luc5 * luc12 * luc17 - luc9 * luc13 * luc18);
255  auto luc20 = src(5, 5) - (luc0 * luc15 * luc15 + luc2 * luc16 * luc16 + luc5 * luc17 * luc17 +
256  luc9 * luc18 * luc18 + luc14 * luc19 * luc19);
257  luc20 = F(1.0) / luc20;
258 
// liRC is element (R-1, C-1) of L^-1 (its unit diagonal is implicit).
259  auto li21 = -luc1 * luc0;
260  auto li32 = -luc2 * luc4;
261  auto li31 = (luc1 * (luc2 * luc4) - luc3) * luc0;
262  auto li43 = -(luc8 * luc5);
263  auto li42 = (luc4 * luc8 * luc5 - luc7) * luc2;
264  auto li41 = (-luc1 * (luc2 * luc4) * (luc8 * luc5) + luc1 * (luc2 * luc7) + luc3 * (luc8 * luc5) - luc6) * luc0;
265  auto li54 = -luc13 * luc9;
266  auto li53 = (luc13 * luc8 * luc9 - luc12) * luc5;
267  auto li52 = (-luc4 * luc8 * luc13 * luc5 * luc9 + luc4 * luc12 * luc5 + luc7 * luc13 * luc9 - luc11) * luc2;
268  auto li51 = (luc1 * luc4 * luc8 * luc13 * luc2 * luc5 * luc9 - luc13 * luc8 * luc3 * luc9 * luc5 -
269  luc12 * luc4 * luc1 * luc2 * luc5 - luc13 * luc7 * luc1 * luc9 * luc2 + luc11 * luc1 * luc2 +
270  luc12 * luc3 * luc5 + luc13 * luc6 * luc9 - luc10) *
271  luc0;
272 
273  auto li65 = -luc19 * luc14;
274  auto li64 = (luc19 * luc14 * luc13 - luc18) * luc9;
275  auto li63 =
276  (-luc8 * luc13 * (luc19 * luc14) * luc9 + luc8 * luc9 * luc18 + luc12 * (luc19 * luc14) - luc17) * luc5;
277  auto li62 = (luc4 * (luc8 * luc9) * luc13 * luc5 * (luc19 * luc14) - luc18 * luc4 * (luc8 * luc9) * luc5 -
278  luc19 * luc12 * luc4 * luc14 * luc5 - luc19 * luc13 * luc7 * luc14 * luc9 + luc17 * luc4 * luc5 +
279  luc18 * luc7 * luc9 + luc19 * luc11 * luc14 - luc16) *
280  luc2;
281  auto li61 =
282  (-luc19 * luc13 * luc8 * luc4 * luc1 * luc2 * luc5 * luc9 * luc14 +
283  luc18 * luc8 * luc4 * luc1 * luc2 * luc5 * luc9 + luc19 * luc12 * luc4 * luc1 * luc2 * luc5 * luc14 +
284  luc19 * luc13 * luc7 * luc1 * luc2 * luc9 * luc14 + luc19 * luc13 * luc8 * luc3 * luc5 * luc9 * luc14 -
285  luc17 * luc4 * luc1 * luc2 * luc5 - luc18 * luc7 * luc1 * luc2 * luc9 - luc19 * luc11 * luc1 * luc2 * luc14 -
286  luc18 * luc8 * luc3 * luc5 * luc9 - luc19 * luc12 * luc3 * luc5 * luc14 -
287  luc19 * luc13 * luc6 * luc9 * luc14 + luc16 * luc1 * luc2 + luc17 * luc3 * luc5 + luc18 * luc6 * luc9 +
288  luc19 * luc10 * luc14 - luc15) *
289  luc0;
290 
// dst(i, j) = sum_k (1/d_k) * Linv(k, i) * Linv(k, j), lower triangle only.
291  dst(0, 0) = luc20 * li61 * li61 + luc14 * li51 * li51 + luc9 * li41 * li41 + luc5 * li31 * li31 +
292  luc2 * li21 * li21 + luc0;
293  dst(1, 0) = luc20 * li61 * li62 + luc14 * li51 * li52 + luc9 * li41 * li42 + luc5 * li31 * li32 + luc2 * li21;
294  dst(1, 1) = luc20 * li62 * li62 + luc14 * li52 * li52 + luc9 * li42 * li42 + luc5 * li32 * li32 + luc2;
295  dst(2, 0) = luc20 * li61 * li63 + luc14 * li51 * li53 + luc9 * li41 * li43 + luc5 * li31;
296  dst(2, 1) = luc20 * li62 * li63 + luc14 * li52 * li53 + luc9 * li42 * li43 + luc5 * li32;
297  dst(2, 2) = luc20 * li63 * li63 + luc14 * li53 * li53 + luc9 * li43 * li43 + luc5;
298  dst(3, 0) = luc20 * li61 * li64 + luc14 * li51 * li54 + luc9 * li41;
299  dst(3, 1) = luc20 * li62 * li64 + luc14 * li52 * li54 + luc9 * li42;
300  dst(3, 2) = luc20 * li63 * li64 + luc14 * li53 * li54 + luc9 * li43;
301  dst(3, 3) = luc20 * li64 * li64 + luc14 * li54 * li54 + luc9;
302  dst(4, 0) = luc20 * li61 * li65 + luc14 * li51;
303  dst(4, 1) = luc20 * li62 * li65 + luc14 * li52;
304  dst(4, 2) = luc20 * li63 * li65 + luc14 * li53;
305  dst(4, 3) = luc20 * li64 * li65 + luc14 * li54;
306  dst(4, 4) = luc20 * li65 * li65 + luc14;
307  dst(5, 0) = luc20 * li61;
308  dst(5, 1) = luc20 * li62;
309  dst(5, 2) = luc20 * li63;
310  dst(5, 3) = luc20 * li64;
311  dst(5, 4) = luc20 * li65;
312  dst(5, 5) = luc20;
313  }
314 
// symmetrizeKK(dst): mirror the strictly-lower part of the leading K x K block
// of dst into its strictly-upper part, i.e. dst(i, j) = dst(j, i) for
// 0 <= i < j < K. No other element is touched. Each level first delegates to
// the (K-1) x (K-1) version and then fills the new column K-1, so the writes
// happen column by column, top to bottom.

// K == 1: a single element has no off-diagonal part. The parameter is left
// unnamed so the empty body does not trigger -Wunused-parameter.
template <typename M>
inline constexpr void symmetrize11(M& /*dst*/) {}

template <typename M>
inline constexpr void symmetrize22(M& dst) {
  dst(0, 1) = dst(1, 0);
}

template <typename M>
inline constexpr void symmetrize33(M& dst) {
  symmetrize22(dst);
  for (int i = 0; i < 2; ++i)
    dst(i, 2) = dst(2, i);
}

template <typename M>
inline constexpr void symmetrize44(M& dst) {
  symmetrize33(dst);
  for (int i = 0; i < 3; ++i)
    dst(i, 3) = dst(3, i);
}

template <typename M>
inline constexpr void symmetrize55(M& dst) {
  symmetrize44(dst);
  for (int i = 0; i < 4; ++i)
    dst(i, 4) = dst(4, i);
}

template <typename M>
inline constexpr void symmetrize66(M& dst) {
  symmetrize55(dst);
  for (int i = 0; i < 5; ++i)
    dst(i, 5) = dst(5, i);
}
356 
// Inverter<M1, M2, N>::eval(src, dst): dispatch on the compile-time dimension N.
//
// Primary template: the fallback for any N without a dedicated specialization
// below; it delegates to the matrix's own inverse().
// NOTE(review): for Eigen types MatrixBase::inverse() is provided by the
// <Eigen/LU> module, which this header does not include. Callers reaching this
// path (e.g. dynamic-size matrices) presumably include it themselves; confirm.
template <typename M1, typename M2, int N>
struct Inverter {
  static constexpr void eval(M1 const& src, M2& dst) { dst = src.inverse(); }
};

// N == 1: the reciprocal of the single element; nothing to mirror.
// always_inline for consistency with the N == 2 ... 6 specializations.
template <typename M1, typename M2>
struct Inverter<M1, M2, 1> {
  static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) { invert11(src, dst); }
};

// N == 2 ... 6: the closed-form invertNN fills the lower triangle of dst,
// then symmetrizeNN copies it into the upper triangle.
template <typename M1, typename M2>
struct Inverter<M1, M2, 2> {
  static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
    invert22(src, dst);
    symmetrize22(dst);
  }
};

template <typename M1, typename M2>
struct Inverter<M1, M2, 3> {
  static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
    invert33(src, dst);
    symmetrize33(dst);
  }
};

template <typename M1, typename M2>
struct Inverter<M1, M2, 4> {
  static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
    invert44(src, dst);
    symmetrize44(dst);
  }
};

template <typename M1, typename M2>
struct Inverter<M1, M2, 5> {
  static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
    invert55(src, dst);
    symmetrize55(dst);
  }
};

template <typename M1, typename M2>
struct Inverter<M1, M2, 6> {
  static constexpr void __attribute__((always_inline)) eval(M1 const& src, M2& dst) {
    invert66(src, dst);
    symmetrize66(dst);
  }
};
406 
407  // Eigen interface
408  template <typename M1, typename M2>
409  inline constexpr void __attribute__((always_inline)) invert(M1 const& src, M2& dst) {
410  if constexpr (M2::ColsAtCompileTime < 7)
411  Inverter<M1, M2, M2::ColsAtCompileTime>::eval(src, dst);
412  else
413  invertNN(src, dst);
414  }
415 
416  } // namespace cholesky
417 } // namespace math
418 
419 #endif // DataFormat_Math_choleskyInversion_h
double Scalar
Definition: Definitions.h:25
#define __host__
constexpr void symmetrize11(M &dst)
constexpr void symmetrize55(M &dst)
constexpr void symmetrize44(M &dst)
static constexpr void eval(M1 const &src, M2 &dst)
static constexpr void eval(M1 const &src, M2 &dst)
constexpr void symmetrize33(M &dst)
constexpr void invertNN(M1 const &src, M2 &dst)
#define N
Definition: blowfish.cc:9
constexpr void __attribute__((always_inline)) invert11(M1 const &src
constexpr void M2 & dst
double a
Definition: hdecay.h:119
constexpr void symmetrize22(M &dst)
static uInt32 F(BLOWFISH_CTX *ctx, uInt32 x)
Definition: blowfish.cc:163
long double T
#define __device__
constexpr void symmetrize66(M &dst)