1 #include <immintrin.h>
2
/*
 * On GCC/Clang, tag the AVX kernel so that it compiles even when the
 * translation unit itself is not built with -mavx.  Other compilers get
 * an empty macro.
 */
#if defined(__GNUC__) || defined(__clang__)
#define CDOT_TARGET_AVX __attribute__((target("avx")))
#else
#define CDOT_TARGET_AVX
#endif

/* Scalar kernel: adds the products of elements [first, n) onto *sum_r / *sum_i. */
static void cdot_scalar_range(
    int first, int n,
    const float *a_r, const float *a_i,
    const float *b_r, const float *b_i,
    float *sum_r, float *sum_i)
{
    for (int i = first; i < n; i++) {
        *sum_r += (a_r[i] * b_r[i]) - (a_i[i] * b_i[i]);
        *sum_i += (a_r[i] * b_i[i]) + (a_i[i] * b_r[i]);
    }
}

/*
 * SSE kernel: processes every complete group of 4 elements, stores the
 * reduced partial sums in *sum_r / *sum_i and returns the number of
 * elements consumed (the index of the first unprocessed element).
 */
static int cdot_sse_blocks(
    int n,
    const float *a_r, const float *a_i,
    const float *b_r, const float *b_i,
    float *sum_r, float *sum_i)
{
    __m128 acc_r = _mm_setzero_ps();
    __m128 acc_i = _mm_setzero_ps();
    float lanes_r[4], lanes_i[4];
    int i = 0;

    /* "n - i >= 4" never touches memory past element n - 1. */
    for (; n - i >= 4; i += 4) {
        /* Unaligned loads: the interface makes no alignment promise. */
        const __m128 ar = _mm_loadu_ps(a_r + i);
        const __m128 ai = _mm_loadu_ps(a_i + i);
        const __m128 br = _mm_loadu_ps(b_r + i);
        const __m128 bi = _mm_loadu_ps(b_i + i);

        acc_r = _mm_add_ps(acc_r, _mm_sub_ps(_mm_mul_ps(ar, br), _mm_mul_ps(ai, bi)));
        acc_i = _mm_add_ps(acc_i, _mm_add_ps(_mm_mul_ps(ar, bi), _mm_mul_ps(ai, br)));
    }

    /* Horizontal reduction of the four lanes. */
    _mm_storeu_ps(lanes_r, acc_r);
    _mm_storeu_ps(lanes_i, acc_i);
    *sum_r = lanes_r[0] + lanes_r[1] + lanes_r[2] + lanes_r[3];
    *sum_i = lanes_i[0] + lanes_i[1] + lanes_i[2] + lanes_i[3];
    return i;
}

/*
 * AVX kernel: same contract as cdot_sse_blocks(), but works on complete
 * groups of 8 elements.  Must only be called on a CPU that supports AVX.
 */
CDOT_TARGET_AVX
static int cdot_avx_blocks(
    int n,
    const float *a_r, const float *a_i,
    const float *b_r, const float *b_i,
    float *sum_r, float *sum_i)
{
    __m256 acc_r = _mm256_setzero_ps();
    __m256 acc_i = _mm256_setzero_ps();
    float lanes_r[8], lanes_i[8];
    int i = 0;

    /* "n - i >= 8" never touches memory past element n - 1. */
    for (; n - i >= 8; i += 8) {
        /* Unaligned loads: the interface makes no alignment promise. */
        const __m256 ar = _mm256_loadu_ps(a_r + i);
        const __m256 ai = _mm256_loadu_ps(a_i + i);
        const __m256 br = _mm256_loadu_ps(b_r + i);
        const __m256 bi = _mm256_loadu_ps(b_i + i);

        acc_r = _mm256_add_ps(acc_r, _mm256_sub_ps(_mm256_mul_ps(ar, br), _mm256_mul_ps(ai, bi)));
        acc_i = _mm256_add_ps(acc_i, _mm256_add_ps(_mm256_mul_ps(ar, bi), _mm256_mul_ps(ai, br)));
    }

    /* Horizontal reduction of the eight lanes. */
    _mm256_storeu_ps(lanes_r, acc_r);
    _mm256_storeu_ps(lanes_i, acc_i);
    *sum_r = lanes_r[0] + lanes_r[1] + lanes_r[2] + lanes_r[3]
           + lanes_r[4] + lanes_r[5] + lanes_r[6] + lanes_r[7];
    *sum_i = lanes_i[0] + lanes_i[1] + lanes_i[2] + lanes_i[3]
           + lanes_i[4] + lanes_i[5] + lanes_i[6] + lanes_i[7];
    return i;
}

/**
 * Complex dot product (no conjugation) of two vectors stored as separate
 * real and imaginary arrays:
 *
 *     x = sum over i in [0, n) of (a_r[i] + j*a_i[i]) * (b_r[i] + j*b_i[i])
 *
 * @param simd  Kernel selector: 0 = scalar, 1 = SSE, 2 = AVX.  Any other
 *              value performs no computation and yields 0 + 0j.  Selecting
 *              2 requires a CPU with AVX support.
 * @param n     Number of elements in each array; n <= 0 yields 0 + 0j.
 * @param a_r   Real parts of the first vector (n floats).
 * @param a_i   Imaginary parts of the first vector (n floats).
 * @param b_r   Real parts of the second vector (n floats).
 * @param b_i   Imaginary parts of the second vector (n floats).
 * @param x_r   Receives the real part of the result.
 * @param x_i   Receives the imaginary part of the result.
 *
 * The input arrays may have any alignment and n need not be a multiple of
 * the vector width: the SIMD kernels handle the complete vector groups and
 * the remaining elements are accumulated with scalar code.  Because the
 * kernels add the terms in different orders, results may differ between
 * modes by floating-point rounding.
 */
void cdot(
    int simd, int n,
    const float *a_r, const float *a_i,
    const float *b_r, const float *b_i,
    float *x_r, float *x_i)
{
    float sum_r = 0;
    float sum_i = 0;
    int done;

    switch (simd) {
    case 0:
        /* no SIMD */
        cdot_scalar_range(0, n, a_r, a_i, b_r, b_i, &sum_r, &sum_i);
        break;
    case 1:
        /* SSE, 4 floats per step, scalar code for the remainder */
        done = cdot_sse_blocks(n, a_r, a_i, b_r, b_i, &sum_r, &sum_i);
        cdot_scalar_range(done, n, a_r, a_i, b_r, b_i, &sum_r, &sum_i);
        break;
    case 2:
        /* AVX, 8 floats per step, scalar code for the remainder */
        done = cdot_avx_blocks(n, a_r, a_i, b_r, b_i, &sum_r, &sum_i);
        cdot_scalar_range(done, n, a_r, a_i, b_r, b_i, &sum_r, &sum_i);
        break;
    default:
        /* Unrecognised selector: the result stays 0 + 0j. */
        break;
    }

    *x_r = sum_r;
    *x_i = sum_i;
}