1 #ifndef RecoTracker_MkFitCore_src_Matriplex_MatriplexCommon_h
2 #define RecoTracker_MkFitCore_src_Matriplex_MatriplexCommon_h
13 #if defined(__x86_64__)
14 #include "immintrin.h"
// Intrinsics are opt-in: MPLEX_INTRINSICS is defined only when the build sets
// MPLEX_USE_INTRINSICS and the compiler targets AVX or AVX-512F.
19 #if defined(MPLEX_USE_INTRINSICS)
21 #if defined(__AVX__) || defined(__AVX512F__)
23 #define MPLEX_INTRINSICS
// AVX-512F configuration: IntrVec_t is the 512-bit (64-byte) __m512 vector and
// the helper macros below map onto the _mm512_* single-precision intrinsics.
27 #if defined(__AVX512F__)
29 typedef __m512 IntrVec_t;
30 #define MPLEX_INTRINSICS_WIDTH_BYTES 64
31 #define MPLEX_INTRINSICS_WIDTH_BITS 512
32 #define AVX512_INTRINSICS
33 #define GATHER_INTRINSICS
// Declares a __m512i variable called `name`, loaded from `arr` with the aligned
// 32-bit-integer load. The expansion already ends in ';'.
// NOTE(review): presumably `arr` must be 64-byte aligned -- confirm at call sites.
34 #define GATHER_IDX_LOAD(name, arr) __m512i name = _mm512_load_epi32(arr);
// LD / ST: load / store one vector at &a[i * N + n] using the aligned
// load/store intrinsics. `N` and `n` are not macro parameters; they are taken
// from the scope in which the macro is expanded.
// NOTE(review): `a` and `i` are not parenthesized in the expansion, so
// LD(a, i + 1) would index a[i + 1 * N + n]; pass plain identifiers/literals.
36 #define LD(a, i) _mm512_load_ps(&a[i * N + n])
37 #define ST(a, i, r) _mm512_store_ps(&a[i * N + n], r)
// Element-wise add, multiply, and fused multiply-add (a * b + v).
38 #define ADD(a, b) _mm512_add_ps(a, b)
39 #define MUL(a, b) _mm512_mul_ps(a, b)
40 #define FMA(a, b, v) _mm512_fmadd_ps(a, b, v)
// AVX2 + FMA configuration: IntrVec_t is the 256-bit (32-byte) __m256 vector
// and the helper macros map onto the _mm256_* single-precision intrinsics.
42 #elif defined(__AVX2__) && defined(__FMA__)
44 typedef __m256 IntrVec_t;
45 #define MPLEX_INTRINSICS_WIDTH_BYTES 32
46 #define MPLEX_INTRINSICS_WIDTH_BITS 256
47 #define AVX2_INTRINSICS
48 #define GATHER_INTRINSICS
// Declares a __m256i variable called `name`, loaded from `arr` (reinterpreted
// as a const __m256i pointer) with the aligned integer load. The expansion
// already ends in ';'.
// NOTE(review): presumably `arr` must be 32-byte aligned -- confirm at call sites.
50 #define GATHER_IDX_LOAD(name, arr) __m256i name = _mm256_load_si256(reinterpret_cast<const __m256i *>(arr));
// LD / ST: load / store one vector at &a[i * N + n] using the aligned
// load/store intrinsics. `N` and `n` are not macro parameters; they are taken
// from the scope in which the macro is expanded.
// NOTE(review): `a` and `i` are not parenthesized in the expansion, so
// LD(a, i + 1) would index a[i + 1 * N + n]; pass plain identifiers/literals.
52 #define LD(a, i) _mm256_load_ps(&a[i * N + n])
53 #define ST(a, i, r) _mm256_store_ps(&a[i * N + n], r)
// Element-wise add, multiply, and fused multiply-add (a * b + v).
54 #define ADD(a, b) _mm256_add_ps(a, b)
55 #define MUL(a, b) _mm256_mul_ps(a, b)
56 #define FMA(a, b, v) _mm256_fmadd_ps(a, b, v)
// Plain AVX configuration (reached when the AVX-512F and AVX2+FMA branches do
// not apply): IntrVec_t is the 256-bit (32-byte) __m256 vector. No gather
// macros appear in this branch as shown here.
58 #elif defined(__AVX__)
60 typedef __m256 IntrVec_t;
61 #define MPLEX_INTRINSICS_WIDTH_BYTES 32
62 #define MPLEX_INTRINSICS_WIDTH_BITS 256
63 #define AVX_INTRINSICS
// LD / ST: load / store one vector at &a[i * N + n] using the aligned
// load/store intrinsics. `N` and `n` are not macro parameters; they are taken
// from the scope in which the macro is expanded.
// NOTE(review): `a` and `i` are not parenthesized in the expansion, so
// LD(a, i + 1) would index a[i + 1 * N + n]; pass plain identifiers/literals.
65 #define LD(a, i) _mm256_load_ps(&a[i * N + n])
66 #define ST(a, i, r) _mm256_store_ps(&a[i * N + n], r)
// Element-wise add and multiply.
67 #define ADD(a, b) _mm256_add_ps(a, b)
68 #define MUL(a, b) _mm256_mul_ps(a, b)
// Fallback FMA for the plain-AVX branch, where FMA is a function rather than
// a macro over a fused intrinsic. Returns a * b + v, computed as a separate
// multiply followed by an add.
// NOTE(review): because the product is rounded before the add, results may
// differ in the last bits from the fused _mm256_fmadd_ps used in the AVX2
// branch -- confirm this is acceptable for callers comparing across builds.
70 inline __m256 FMA(
const __m256 &
a,
const __m256 &
b,
const __m256 &
v) {
// Intermediate product a * b.
71 __m256
temp = _mm256_mul_ps(a, b);
// Add the accumulator v to the product.
72 return _mm256_add_ps(temp, v);
// ASSUME_ALIGNED(a, b): alignment hint for pointer `a` with alignment `b`.
// Under the Intel compiler it expands to __assume_aligned(a, b).
// The second definition (the directive separating the two alternatives is not
// visible in this excerpt) reassigns `a` from __builtin_assume_aligned(a, b),
// cast back to a's own declared type; `a` must therefore be an assignable
// pointer variable, not an arbitrary expression.
79 #ifdef __INTEL_COMPILER
80 #define ASSUME_ALIGNED(a, b) __assume_aligned(a, b)
82 #define ASSUME_ALIGNED(a, b) a = static_cast<decltype(a)>(__builtin_assume_aligned(a, b))
void align_check(const char *pref, void *adr)