4 #ifndef ROOT_Math_CholeskyDecomp
5 #define ROOT_Math_CholeskyDecomp
26 namespace CholeskyDecompHelpers {
29 template<
class F,
unsigned N,
class M>
struct _inverter;
30 template<
class F,
unsigned N,
class V>
struct _solver;
82 {
return fArr[((i * (i + 1)) / 2) + j]; }
85 {
return fArr[((i * (i + 1)) / 2) + j]; }
99 fOk = _decomposer<F, N, M>()(
fL, m);
116 fOk = _decomposer<F, N, PackedArrayAdapter<G> >()(
125 operator bool()
const {
return fOk; }
136 template<
class V>
bool Solve(V& rhs)
const
139 if (
fOk) _solver<F,N,V>()(rhs,
fL);
return fOk;
152 if (
fOk) _inverter<F,N,M>()(m,
fL);
return fOk;
171 _inverter<F,N,PackedArrayAdapter<G> >()(adapted,
fL);
178 namespace CholeskyDecompHelpers {
180 template<
class F,
unsigned N,
class M>
struct _decomposer
199 for (
unsigned i = 0;
i <
N; base1 += ++
i) {
203 for (
unsigned j = 0;
j <
i; base2 += ++
j) {
205 for (
unsigned k =
j;
k--; )
206 tmp -= base1[
k] * base2[
k];
207 base1[
j] = tmp *= base2[
j];
209 tmpdiag += tmp *
tmp;
212 tmpdiag =
src(i, i) - tmpdiag;
214 if (tmpdiag <= F(0))
return false;
222 template<
class F,
unsigned N,
class M>
struct _inverter
228 F
l[
N * (
N + 1) / 2];
232 for (
unsigned i = 1;
i <
N; base1 += ++
i) {
233 for (
unsigned j = 0;
j <
i; ++
j) {
235 const F *base2 = &
l[(i * (i - 1)) / 2];
236 for (
unsigned k = i;
k-- >
j; base2 -=
k)
237 tmp -= base1[
k] * base2[j];
238 base1[
j] = tmp * base1[
i];
243 for (
unsigned i = N;
i--; ) {
244 for (
unsigned j =
i + 1;
j--; ) {
246 base1 = &
l[(N * (N - 1)) / 2];
247 for (
unsigned k = N;
k-- >
i; base1 -=
k)
248 tmp += base1[i] * base1[
j];
256 template<
class F,
unsigned N,
class V>
struct _solver
262 for (
unsigned k = 0;
k <
N; ++
k) {
263 const unsigned base = (
k * (
k + 1)) / 2;
265 for (
unsigned i =
k;
i--; )
266 sum += rhs[
i] * l[base +
i];
268 rhs[
k] = (rhs[
k] - sum) * l[base +
k];
271 for (
unsigned k = N;
k--; ) {
273 for (
unsigned i = N; --
i >
k; )
274 sum += rhs[
i] * l[(
i * (
i + 1)) / 2 +
k];
276 rhs[
k] = (rhs[
k] - sum) * l[(k * (k + 1)) / 2 + k];
287 if (
src(0,0) <= F(0))
return false;
289 dst[1] =
src(1,0) * dst[0];
290 dst[2] =
src(1,1) - dst[1] * dst[1];
291 if (dst[2] <= F(0))
return false;
293 dst[3] =
src(2,0) * dst[0];
294 dst[4] = (
src(2,1) - dst[1] * dst[3]) * dst[2];
295 dst[5] =
src(2,2) - (dst[3] * dst[3] + dst[4] * dst[4]);
296 if (dst[5] <= F(0))
return false;
298 dst[6] =
src(3,0) * dst[0];
299 dst[7] = (
src(3,1) - dst[1] * dst[6]) * dst[2];
300 dst[8] = (
src(3,2) - dst[3] * dst[6] - dst[4] * dst[7]) * dst[5];
301 dst[9] =
src(3,3) - (dst[6] * dst[6] + dst[7] * dst[7] + dst[8] * dst[8]);
302 if (dst[9] <= F(0))
return false;
304 dst[10] =
src(4,0) * dst[0];
305 dst[11] = (
src(4,1) - dst[1] * dst[10]) * dst[2];
306 dst[12] = (
src(4,2) - dst[3] * dst[10] - dst[4] * dst[11]) * dst[5];
307 dst[13] = (
src(4,3) - dst[6] * dst[10] - dst[7] * dst[11] - dst[8] * dst[12]) * dst[9];
308 dst[14] =
src(4,4) - (dst[10]*dst[10]+dst[11]*dst[11]+dst[12]*dst[12]+dst[13]*dst[13]);
309 if (dst[14] <= F(0))
return false;
310 else dst[14] =
std::sqrt(F(1) / dst[14]);
311 dst[15] =
src(5,0) * dst[0];
312 dst[16] = (
src(5,1) - dst[1] * dst[15]) * dst[2];
313 dst[17] = (
src(5,2) - dst[3] * dst[15] - dst[4] * dst[16]) * dst[5];
314 dst[18] = (
src(5,3) - dst[6] * dst[15] - dst[7] * dst[16] - dst[8] * dst[17]) * dst[9];
315 dst[19] = (
src(5,4) - dst[10] * dst[15] - dst[11] * dst[16] - dst[12] * dst[17] - dst[13] * dst[18]) * dst[14];
316 dst[20] =
src(5,5) - (dst[15]*dst[15]+dst[16]*dst[16]+dst[17]*dst[17]+dst[18]*dst[18]+dst[19]*dst[19]);
317 if (dst[20] <= F(0))
return false;
318 else dst[20] =
std::sqrt(F(1) / dst[20]);
328 if (
src(0,0) <= F(0))
return false;
330 dst[1] =
src(1,0) * dst[0];
331 dst[2] =
src(1,1) - dst[1] * dst[1];
332 if (dst[2] <= F(0))
return false;
334 dst[3] =
src(2,0) * dst[0];
335 dst[4] = (
src(2,1) - dst[1] * dst[3]) * dst[2];
336 dst[5] =
src(2,2) - (dst[3] * dst[3] + dst[4] * dst[4]);
337 if (dst[5] <= F(0))
return false;
339 dst[6] =
src(3,0) * dst[0];
340 dst[7] = (
src(3,1) - dst[1] * dst[6]) * dst[2];
341 dst[8] = (
src(3,2) - dst[3] * dst[6] - dst[4] * dst[7]) * dst[5];
342 dst[9] =
src(3,3) - (dst[6] * dst[6] + dst[7] * dst[7] + dst[8] * dst[8]);
343 if (dst[9] <= F(0))
return false;
345 dst[10] =
src(4,0) * dst[0];
346 dst[11] = (
src(4,1) - dst[1] * dst[10]) * dst[2];
347 dst[12] = (
src(4,2) - dst[3] * dst[10] - dst[4] * dst[11]) * dst[5];
348 dst[13] = (
src(4,3) - dst[6] * dst[10] - dst[7] * dst[11] - dst[8] * dst[12]) * dst[9];
349 dst[14] =
src(4,4) - (dst[10]*dst[10]+dst[11]*dst[11]+dst[12]*dst[12]+dst[13]*dst[13]);
350 if (dst[14] <= F(0))
return false;
351 else dst[14] =
std::sqrt(F(1) / dst[14]);
361 if (
src(0,0) <= F(0))
return false;
363 dst[1] =
src(1,0) * dst[0];
364 dst[2] =
src(1,1) - dst[1] * dst[1];
365 if (dst[2] <= F(0))
return false;
367 dst[3] =
src(2,0) * dst[0];
368 dst[4] = (
src(2,1) - dst[1] * dst[3]) * dst[2];
369 dst[5] =
src(2,2) - (dst[3] * dst[3] + dst[4] * dst[4]);
370 if (dst[5] <= F(0))
return false;
372 dst[6] =
src(3,0) * dst[0];
373 dst[7] = (
src(3,1) - dst[1] * dst[6]) * dst[2];
374 dst[8] = (
src(3,2) - dst[3] * dst[6] - dst[4] * dst[7]) * dst[5];
375 dst[9] =
src(3,3) - (dst[6] * dst[6] + dst[7] * dst[7] + dst[8] * dst[8]);
376 if (dst[9] <= F(0))
return false;
387 if (
src(0,0) <= F(0))
return false;
389 dst[1] =
src(1,0) * dst[0];
390 dst[2] =
src(1,1) - dst[1] * dst[1];
391 if (dst[2] <= F(0))
return false;
393 dst[3] =
src(2,0) * dst[0];
394 dst[4] = (
src(2,1) - dst[1] * dst[3]) * dst[2];
395 dst[5] =
src(2,2) - (dst[3] * dst[3] + dst[4] * dst[4]);
396 if (dst[5] <= F(0))
return false;
407 if (
src(0,0) <= F(0))
return false;
409 dst[1] =
src(1,0) * dst[0];
410 dst[2] =
src(1,1) - dst[1] * dst[1];
411 if (dst[2] <= F(0))
return false;
422 if (
src(0,0) <= F(0))
return false;
441 const F li21 = -src[1] * src[0] * src[2];
442 const F li32 = -src[4] * src[2] * src[5];
443 const F li31 = (src[1] * src[4] * src[2] - src[3]) * src[0] * src[5];
444 const F li43 = -src[8] * src[9] * src[5];
445 const F li42 = (src[4] * src[8] * src[5] - src[7]) * src[2] * src[9];
446 const F li41 = (-src[1] * src[4] * src[8] * src[2] * src[5] +
447 src[1] * src[7] * src[2] + src[3] * src[8] * src[5] - src[6]) * src[0] * src[9];
448 const F li54 = -src[13] * src[14] * src[9];
449 const F li53 = (src[13] * src[8] * src[9] - src[12]) * src[5] * src[14];
450 const F li52 = (-src[4] * src[8] * src[13] * src[5] * src[9] +
451 src[4] * src[12] * src[5] + src[7] * src[13] * src[9] - src[11]) * src[2] * src[14];
452 const F li51 = (src[1]*src[4]*src[8]*src[13]*src[2]*src[5]*src[9] -
453 src[13]*src[8]*src[3]*src[9]*src[5] - src[12]*src[4]*src[1]*src[2]*src[5] - src[13]*src[7]*src[1]*src[9]*src[2] +
454 src[11]*src[1]*src[2] + src[12]*src[3]*src[5] + src[13]*src[6]*src[9] -src[10]) * src[0] * src[14];
455 const F li65 = -src[19] * src[20] * src[14];
456 const F li64 = (src[19] * src[13] * src[14] - src[18]) * src[9] * src[20];
457 const F li63 = (-src[8] * src[13] * src[19] * src[9] * src[14] +
458 src[8] * src[18] * src[9] + src[12] * src[19] * src[14] - src[17]) * src[5] * src[20];
459 const F li62 = (src[4]*src[8]*src[13]*src[19]*src[5]*src[9]*src[14] -
460 src[18]*src[8]*src[4]*src[9]*src[5] - src[19]*src[12]*src[4]*src[14]*src[5] -src[19]*src[13]*src[7]*src[14]*src[9] +
461 src[17]*src[4]*src[5] + src[18]*src[7]*src[9] + src[19]*src[11]*src[14] - src[16]) * src[2] * src[20];
462 const F li61 = (-src[19]*src[13]*src[8]*src[4]*src[1]*src[2]*src[5]*src[9]*src[14] +
463 src[18]*src[8]*src[4]*src[1]*src[2]*src[5]*src[9] + src[19]*src[12]*src[4]*src[1]*src[2]*src[5]*src[14] +
464 src[19]*src[13]*src[7]*src[1]*src[2]*src[9]*src[14] + src[19]*src[13]*src[8]*src[3]*src[5]*src[9]*src[14] -
465 src[17]*src[4]*src[1]*src[2]*src[5] - src[18]*src[7]*src[1]*src[2]*src[9] - src[19]*src[11]*src[1]*src[2]*src[14] -
466 src[18]*src[8]*src[3]*src[5]*src[9] - src[19]*src[12]*src[3]*src[5]*src[14] - src[19]*src[13]*src[6]*src[9]*src[14] +
467 src[16]*src[1]*src[2] + src[17]*src[3]*src[5] + src[18]*src[6]*src[9] + src[19]*src[10]*src[14] - src[15]) *
470 dst(0,0) = li61*li61 + li51*li51 + li41*li41 + li31*li31 + li21*li21 + src[0]*src[0];
471 dst(1,0) = li61*li62 + li51*li52 + li41*li42 + li31*li32 + li21*src[2];
472 dst(1,1) = li62*li62 + li52*li52 + li42*li42 + li32*li32 + src[2]*src[2];
473 dst(2,0) = li61*li63 + li51*li53 + li41*li43 + li31*src[5];
474 dst(2,1) = li62*li63 + li52*li53 + li42*li43 + li32*src[5];
475 dst(2,2) = li63*li63 + li53*li53 + li43*li43 + src[5]*src[5];
476 dst(3,0) = li61*li64 + li51*li54 + li41*src[9];
477 dst(3,1) = li62*li64 + li52*li54 + li42*src[9];
478 dst(3,2) = li63*li64 + li53*li54 + li43*src[9];
479 dst(3,3) = li64*li64 + li54*li54 + src[9]*src[9];
480 dst(4,0) = li61*li65 + li51*src[14];
481 dst(4,1) = li62*li65 + li52*src[14];
482 dst(4,2) = li63*li65 + li53*src[14];
483 dst(4,3) = li64*li65 + li54*src[14];
484 dst(4,4) = li65*li65 + src[14]*src[14];
485 dst(5,0) = li61*src[20];
486 dst(5,1) = li62*src[20];
487 dst(5,2) = li63*src[20];
488 dst(5,3) = li64*src[20];
489 dst(5,4) = li65*src[20];
490 dst(5,5) = src[20]*src[20];
499 const F li21 = -src[1] * src[0] * src[2];
500 const F li32 = -src[4] * src[2] * src[5];
501 const F li31 = (src[1] * src[4] * src[2] - src[3]) * src[0] * src[5];
502 const F li43 = -src[8] * src[9] * src[5];
503 const F li42 = (src[4] * src[8] * src[5] - src[7]) * src[2] * src[9];
504 const F li41 = (-src[1] * src[4] * src[8] * src[2] * src[5] +
505 src[1] * src[7] * src[2] + src[3] * src[8] * src[5] - src[6]) * src[0] * src[9];
506 const F li54 = -src[13] * src[14] * src[9];
507 const F li53 = (src[13] * src[8] * src[9] - src[12]) * src[5] * src[14];
508 const F li52 = (-src[4] * src[8] * src[13] * src[5] * src[9] +
509 src[4] * src[12] * src[5] + src[7] * src[13] * src[9] - src[11]) * src[2] * src[14];
510 const F li51 = (src[1]*src[4]*src[8]*src[13]*src[2]*src[5]*src[9] -
511 src[13]*src[8]*src[3]*src[9]*src[5] - src[12]*src[4]*src[1]*src[2]*src[5] - src[13]*src[7]*src[1]*src[9]*src[2] +
512 src[11]*src[1]*src[2] + src[12]*src[3]*src[5] + src[13]*src[6]*src[9] -src[10]) * src[0] * src[14];
514 dst(0,0) = li51*li51 + li41*li41 + li31*li31 + li21*li21 + src[0]*src[0];
515 dst(1,0) = li51*li52 + li41*li42 + li31*li32 + li21*src[2];
516 dst(1,1) = li52*li52 + li42*li42 + li32*li32 + src[2]*src[2];
517 dst(2,0) = li51*li53 + li41*li43 + li31*src[5];
518 dst(2,1) = li52*li53 + li42*li43 + li32*src[5];
519 dst(2,2) = li53*li53 + li43*li43 + src[5]*src[5];
520 dst(3,0) = li51*li54 + li41*src[9];
521 dst(3,1) = li52*li54 + li42*src[9];
522 dst(3,2) = li53*li54 + li43*src[9];
523 dst(3,3) = li54*li54 + src[9]*src[9];
524 dst(4,0) = li51*src[14];
525 dst(4,1) = li52*src[14];
526 dst(4,2) = li53*src[14];
527 dst(4,3) = li54*src[14];
528 dst(4,4) = src[14]*src[14];
537 const F li21 = -src[1] * src[0] * src[2];
538 const F li32 = -src[4] * src[2] * src[5];
539 const F li31 = (src[1] * src[4] * src[2] - src[3]) * src[0] * src[5];
540 const F li43 = -src[8] * src[9] * src[5];
541 const F li42 = (src[4] * src[8] * src[5] - src[7]) * src[2] * src[9];
542 const F li41 = (-src[1] * src[4] * src[8] * src[2] * src[5] +
543 src[1] * src[7] * src[2] + src[3] * src[8] * src[5] - src[6]) * src[0] * src[9];
545 dst(0,0) = li41*li41 + li31*li31 + li21*li21 + src[0]*src[0];
546 dst(1,0) = li41*li42 + li31*li32 + li21*src[2];
547 dst(1,1) = li42*li42 + li32*li32 + src[2]*src[2];
548 dst(2,0) = li41*li43 + li31*src[5];
549 dst(2,1) = li42*li43 + li32*src[5];
550 dst(2,2) = li43*li43 + src[5]*src[5];
551 dst(3,0) = li41*src[9];
552 dst(3,1) = li42*src[9];
553 dst(3,2) = li43*src[9];
554 dst(3,3) = src[9]*src[9];
563 const F li21 = -src[1] * src[0] * src[2];
564 const F li32 = -src[4] * src[2] * src[5];
565 const F li31 = (src[1] * src[4] * src[2] - src[3]) * src[0] * src[5];
567 dst(0,0) = li31*li31 + li21*li21 + src[0]*src[0];
568 dst(1,0) = li31*li32 + li21*src[2];
569 dst(1,1) = li32*li32 + src[2]*src[2];
570 dst(2,0) = li31*src[5];
571 dst(2,1) = li32*src[5];
572 dst(2,2) = src[5]*src[5];
581 const F li21 = -src[1] * src[0] * src[2];
583 dst(0,0) = li21*li21 + src[0]*src[0];
584 dst(1,0) = li21*src[2];
585 dst(1,1) = src[2]*src[2];
594 dst(0,0) = src[0]*src[0];
606 template<
class F,
class V>
struct _solver<F,6,V>
612 const F y0 = rhs[0] * l[0];
613 const F y1 = (rhs[1]-l[1]*y0)*l[2];
614 const F y2 = (rhs[2]-(l[3]*y0+l[4]*y1))*l[5];
615 const F y3 = (rhs[3]-(l[6]*y0+l[7]*y1+l[8]*y2))*l[9];
616 const F y4 = (rhs[4]-(l[10]*y0+l[11]*y1+l[12]*y2+l[13]*y3))*l[14];
617 const F y5 = (rhs[5]-(l[15]*y0+l[16]*y1+l[17]*y2+l[18]*y3+l[19]*y4))*l[20];
620 rhs[4] = (y4-l[19]*rhs[5])*l[14];
621 rhs[3] = (y3-(l[18]*rhs[5]+l[13]*rhs[4]))*l[9];
622 rhs[2] = (y2-(l[17]*rhs[5]+l[12]*rhs[4]+l[8]*rhs[3]))*l[5];
623 rhs[1] = (y1-(l[16]*rhs[5]+l[11]*rhs[4]+l[7]*rhs[3]+l[4]*rhs[2]))*l[2];
624 rhs[0] = (y0-(l[15]*rhs[5]+l[10]*rhs[4]+l[6]*rhs[3]+l[3]*rhs[2]+l[1]*rhs[1]))*l[0];
628 template<
class F,
class V>
struct _solver<F,5,V>
634 const F y0 = rhs[0] * l[0];
635 const F y1 = (rhs[1]-l[1]*y0)*l[2];
636 const F y2 = (rhs[2]-(l[3]*y0+l[4]*y1))*l[5];
637 const F y3 = (rhs[3]-(l[6]*y0+l[7]*y1+l[8]*y2))*l[9];
638 const F y4 = (rhs[4]-(l[10]*y0+l[11]*y1+l[12]*y2+l[13]*y3))*l[14];
641 rhs[3] = (y3-(l[13]*rhs[4]))*l[9];
642 rhs[2] = (y2-(l[12]*rhs[4]+l[8]*rhs[3]))*l[5];
643 rhs[1] = (y1-(l[11]*rhs[4]+l[7]*rhs[3]+l[4]*rhs[2]))*l[2];
644 rhs[0] = (y0-(l[10]*rhs[4]+l[6]*rhs[3]+l[3]*rhs[2]+l[1]*rhs[1]))*l[0];
648 template<
class F,
class V>
struct _solver<F,4,V>
654 const F y0 = rhs[0] * l[0];
655 const F y1 = (rhs[1]-l[1]*y0)*l[2];
656 const F y2 = (rhs[2]-(l[3]*y0+l[4]*y1))*l[5];
657 const F y3 = (rhs[3]-(l[6]*y0+l[7]*y1+l[8]*y2))*l[9];
660 rhs[2] = (y2-(l[8]*rhs[3]))*l[5];
661 rhs[1] = (y1-(l[7]*rhs[3]+l[4]*rhs[2]))*l[2];
662 rhs[0] = (y0-(l[6]*rhs[3]+l[3]*rhs[2]+l[1]*rhs[1]))*l[0];
666 template<
class F,
class V>
struct _solver<F,3,V>
672 const F y0 = rhs[0] * l[0];
673 const F y1 = (rhs[1]-l[1]*y0)*l[2];
674 const F y2 = (rhs[2]-(l[3]*y0+l[4]*y1))*l[5];
677 rhs[1] = (y1-(l[4]*rhs[2]))*l[2];
678 rhs[0] = (y0-(l[3]*rhs[2]+l[1]*rhs[1]))*l[0];
682 template<
class F,
class V>
struct _solver<F,2,V>
688 const F y0 = rhs[0] * l[0];
689 const F y1 = (rhs[1]-l[1]*y0)*l[2];
692 rhs[0] = (y0-(l[1]*rhs[1]))*l[0];
696 template<
class F,
class V>
struct _solver<F,1,V>
702 rhs[0] *= l[0] * l[0];
706 template<
class F,
class V>
struct _solver<F,0,V>
719 #endif // ROOT_Math_CHOLESKYDECOMP
F fL[N *(N+1)/2]
lower triangular matrix L
bool Invert(M &m) const
place the inverse into m
PackedArrayAdapter(G *arr)
constructor
bool operator()(F *dst, const M &src) const
method to do the decomposition
const G operator()(unsigned i, unsigned j) const
read access to elements (make sure that j <= i)
CholeskyDecomp(G *m)
perform a Cholesky decomposition
void operator()(M &dst, const F *src) const
method to do the inversion
bool fOk
flag indicating a successful decomposition
class to compute the Cholesky decomposition of a matrix
bool operator()(F *dst, const M &src) const
method to do the decomposition
G * fArr
pointer to first array element
G & operator()(unsigned i, unsigned j)
write access to elements (make sure that j <= i)
void operator()(M &dst, const F *src) const
method to do the inversion
void operator()(M &dst, const F *src) const
method to do the inversion
adapter for packed arrays (to SMatrix indexing conventions)
void operator()(V &rhs, const F *l) const
method to solve the linear system
void operator()(M &dst, const F *src) const
method to do the inversion
bool operator()(F *dst, const M &src) const
method to do the decomposition
struct to obtain the inverse from a Cholesky decomposition
struct to do a Cholesky decomposition
bool operator()(F *dst, const M &src) const
method to do the decomposition
void operator()(V &rhs, const F *l) const
method to solve the linear system
bool Solve(V &rhs) const
solves a linear system for the given right hand side
bool Invert(G *m) const
place the inverse into m
void operator()(M &dst, const F *src) const
method to do the inversion
bool operator()(F *dst, const M &src) const
method to do the decomposition
void operator()(V &rhs, const F *l) const
method to solve the linear system
CholeskyDecomp(const M &m)
perform a Cholesky decomposition
bool operator()(F *dst, const M &src) const
method to do the decomposition
void operator()(V &rhs, const F *l) const
method to solve the linear system
void operator()(V &rhs, const F *l) const
method to solve the linear system
std::vector< std::vector< double > > tmp
void operator()(V &rhs, const F *l) const
method to solve the linear system
void operator()(V &rhs, const F *l) const
method to solve the linear system
void operator()(M &dst, const F *src) const
method to do the inversion
bool operator()(F *dst, const M &src) const
method to do the decomposition
bool ok() const
returns true if decomposition was successful
void operator()(M &dst, const F *src) const
method to do the inversion
struct to solve a linear system using its Cholesky decomposition