1 #ifndef AVXTILESKERNEL_H
2 #define AVXTILESKERNEL_H
// getOmpSimdTableI: converts 16 per-lane inverse distances (r_1 = 1/r) into
// an integer table index plus linear-interpolation weights.
//   r_1        - input, per-lane 1/r.
//   table_int  - output, per-lane table index, obtained by truncating the
//                scaled coordinate table_f toward zero.
//   tableDiff  - output, table_f - float(table_int) (fractional part).
//   rTableDiff - output, 1 - tableDiff.
// Callers in this file interpolate as tab[idx]*rTableDiff + tab[idx+1]*tableDiff.
// NOTE(review): several source lines of this function are missing from this
// view (including the tail of the table_f expression and the closing brace);
// the comments below describe only the visible code.
9 __forceinline
void getOmpSimdTableI(
const __m512 &r_1, __m512i &table_int,
10 __m512 &tableDiff, __m512 &rTableDiff) {
// Clamp r_1 from above: lanes where KNL_TABLE_MAX_R_1 < r_1 are replaced by
// KNL_TABLE_MAX_R_1, i.e. table_r_1 = min(r_1, KNL_TABLE_MAX_R_1).
12 const __m512 maxv = _mm512_set1_ps(KNL_TABLE_MAX_R_1);
13 const __mmask16 tmask = _mm512_cmplt_ps_mask(maxv, r_1);
14 const __m512 table_r_1 = _mm512_mask_mov_ps(r_1, tmask, maxv);
// Continuous table coordinate: (KNL_TABLE_FACTOR-2) times a second operand
// that lies on a line not visible here (presumably table_r_1 - TODO confirm).
16 const __m512 table_f = _mm512_mul_ps(_mm512_set1_ps(KNL_TABLE_FACTOR-2),
// Integer part via truncating float->int32 conversion.
19 table_int = _mm512_cvttps_epi32(table_f);
// Fractional part and its complement, used as interpolation weights.
21 tableDiff = _mm512_sub_ps(table_f, _mm512_cvtepi32_ps(table_int));
23 rTableDiff = _mm512_sub_ps(_mm512_set1_ps(1.f), tableDiff);
// forceEnergySlow512: table-interpolated slow force (and slow energy) for 16 lanes.
//   r2mask      - lanes to gather/accumulate. Masked-off gather lanes come from
//                 _mm512_undefined_pd(), so forceSlow in those lanes is not
//                 meaningful.
//   slowTable / slowEtable - float tables. Each 64-bit gather reads the
//                 adjacent float pair [idx], [idx+1] (byte offset idx*4).
//   table_int, tableDiff, rTableDiff - table index and interpolation weights.
//   forceSlow   - output: kqq * (-(tab[idx]*rTableDiff) - tab[idx+1]*tableDiff).
//   energySlow  - in/out: in r2mask lanes, energySlow -= kqq *
//                 (etab[idx]*rTableDiff + etab[idx+1]*tableDiff); other lanes
//                 are left unchanged.
// NOTE(review): kqq is used below but does not appear in the visible parameter
// list. Callers pass it as the 2nd argument, so a parameter line seems to be
// missing from this view. The doEnergy guard around the energy section and the
// scopes that make the second t4/t6 declarations legal are also not visible.
// TODO confirm against the full source.
28 template<
bool doEnergy>
29 __forceinline
void forceEnergySlow512(
const __mmask16 r2mask,
31 const float * __restrict__ slowTable,
32 const float * __restrict__ slowEtable,
33 const __m512i &table_int,
34 const __m512 &tableDiff,
35 const __m512 &rTableDiff,
36 __m512 &forceSlow, __m512 &energySlow) {
// Gather lanes 0-7: _mm512_mask_i32logather_pd uses the low 8 int32 indices
// and loads 8 bytes (two floats) per lane.
38 const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
39 r2mask, table_int, slowTable, _MM_SCALE_4);
// Gather lanes 8-15: shuffle control 238 (0b11101110) copies the upper 256
// bits of table_int into the lower half. The mask is shifted down to match.
40 const __m512i table_int2 = _mm512_shuffle_i32x4(table_int, table_int, 238);
41 const __mmask16 r2mask2 = r2mask >> 8;
42 const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
43 r2mask2, table_int2, slowTable, _MM_SCALE_4);
// De-interleave the float pairs: odd 32-bit elements -> tab[idx+1],
// even 32-bit elements -> tab[idx]. The low halves of the index constants are
// on lines not shown (presumably 15,13,...,1 and 14,12,...,0).
44 const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
46 const __m512 tabSlowP1 = _mm512_permutex2var_ps(t0, t4, t1);
47 const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
49 const __m512 tabSlow = _mm512_permutex2var_ps(t0, t6, t1);
// fnmsub(a,b,c) = -(a*b) - c, so forceSlow = -kqq * interpolated table value.
52 forceSlow = _mm512_mul_ps(kqq, _mm512_fnmsub_ps(tabSlow, rTableDiff,
53 _mm512_mul_ps(tabSlowP1, tableDiff)));
// Same gather/de-interleave pattern for the slow energy table.
56 const __m512 t10 = (__m512)_mm512_mask_i32logather_pd(
57 _mm512_undefined_pd(), r2mask, table_int, slowEtable, _MM_SCALE_4);
58 const __m512 t11 = (__m512)_mm512_mask_i32logather_pd(
59 _mm512_undefined_pd(), r2mask2, table_int2, slowEtable, _MM_SCALE_4);
60 const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
62 const __m512 tabSlowEp1 = _mm512_permutex2var_ps(t10, t4, t11);
63 const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
65 const __m512 tabSlowE = _mm512_permutex2var_ps(t10, t6, t11);
// Linear interpolation of the energy table.
68 const __m512 eSlow = _mm512_fmadd_ps(tabSlowE, rTableDiff,
69 _mm512_mul_ps(tabSlowEp1, tableDiff));
// Masked accumulate: fnmadd(kqq, eSlow, energySlow) = energySlow - kqq*eSlow.
71 energySlow = _mm512_mask_mov_ps(energySlow, r2mask,
72 _mm512_fnmadd_ps(kqq, eSlow, energySlow));
// forceEnergyInterp2: per-lane van der Waals + electrostatic force/energy for
// 16 pairs. LJ coefficients A/B come from a type-indexed table (ljTable).
//   r2          - squared pair distances.
//   kqq         - per-lane electrostatic prefactor (multiplies ffast/efast).
//   type_i/type_j - per-lane atom types used to index ljTable.
//   force       - output: vdwB - ffast (written in all lanes).
//   energyVdw / energyElec - in/out, updated only in r2mask lanes.
//   forceSlow / energySlow - passed through to forceEnergySlow512.
//   scaling     - multiplies the vdW force term (vdwB).
//   c1, c3, switchOn2 - switching-function parameters (see switchVal below).
//   iMode       - compile-time mode. iMode==2 selects the analytic fast force.
//                 Otherwise fastTable is interpolated.
// NOTE(review): many source lines are missing from this view. Examples are the
// declarations of table_int, ffast and efast, the conditional scopes around
// the table gathers and the energy/slow sections, and the tails of some
// expressions. ljWidth, cutoff2 and mInvCut3 are also not declared in the
// visible text. Comments describe only the visible code.
// NOTE(review): table coordinates are computed only when iMode==3 || doSlow,
// but the table-based ffast branch is taken for every iMode != 2. Confirm that
// callers restrict iMode to {2,3}.
78 template<
bool doEnergy,
bool doSlow,
int iMode>
79 __forceinline
void forceEnergyInterp2(
const __m512 &r2,
const __m512 &kqq,
80 const __m512i &type_i,
81 const __m512i &type_j, __m512 &force,
82 __m512 &forceSlow, __m512 &energyVdw,
83 __m512 &energyElec, __m512 &energySlow,
84 const __mmask16 r2mask,
85 const float scaling,
const float c1,
86 const float c3,
const float switchOn2,
89 const float cutUnder3,
90 const float * __restrict__ fastTable,
91 const float * __restrict__ energyTable,
92 const float * __restrict__ slowTable,
93 const float * __restrict__ slowEtable,
94 const float * __restrict__ ljTable,
// LJ table index: lj_i = 2*(type_i*ljWidth + type_j). With _MM_SCALE_8 the byte
// offset is lj_i*8, i.e. 16 bytes (4 floats) per (type_i, type_j) entry.
98 const __m512i lj_i = _mm512_slli_epi32(_mm512_add_epi32(
99 _mm512_mullo_epi32(type_i,_mm512_set1_epi32(ljWidth)), type_j), 1);
// Gather the first two floats of each entry: lanes 0-7, then lanes 8-15 via the
// upper-half indices (shuffle control 238) and the shifted mask.
101 const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
102 r2mask, lj_i, ljTable, _MM_SCALE_8);
103 const __m512i lj_i2 = _mm512_shuffle_i32x4(lj_i, lj_i, 238);
104 const __mmask16 r2mask2 = r2mask >> 8;
105 const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
106 r2mask2, lj_i2, ljTable, _MM_SCALE_8);
// De-interleave: B = second float of each entry (odd elements),
// A = first float (even elements). Masked-off lanes are undefined.
107 const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
109 const __m512
B = _mm512_permutex2var_ps(t0, t4, t1);
110 const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
112 const __m512
A = _mm512_permutex2var_ps(t0, t6, t1);
// r_1 = 1/sqrt(r2).
115 const __m512 r_1 = _mm512_invsqrt_ps(r2);
116 __m512 tableDiff, rTableDiff;
// Table coordinates are needed only for table-interpolated fast terms
// (iMode==3) or for the slow terms.
118 if (iMode == 3 || doSlow)
119 getOmpSimdTableI(r_1, table_int, tableDiff, rTableDiff);
122 __m512 tabFast, tabFastp1, tabEnergy, tabEnergyp1;
123 __m512 tabSlow, tabSlowp1, tabSlowE, tabSlowEp1;
// Fast-force table: gather the pairs fastTable[idx], fastTable[idx+1] and
// de-interleave them. The enclosing scope/condition is on lines not shown;
// it presumably makes the re-declared t0/t1/t4/t6/r2mask2 legal.
125 const __m512 t0 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
126 r2mask, table_int, fastTable, _MM_SCALE_4);
127 const __m512i table_int2 = _mm512_shuffle_i32x4(table_int, table_int, 238);
128 const __mmask16 r2mask2 = r2mask >> 8;
129 const __m512 t1 = (__m512)_mm512_mask_i32logather_pd(_mm512_undefined_pd(),
130 r2mask2, table_int2, fastTable, _MM_SCALE_4);
131 const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
133 tabFastp1 = _mm512_permutex2var_ps(t0, t4, t1);
134 const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
136 tabFast = _mm512_permutex2var_ps(t0, t6, t1);
// Electrostatic energy table: same gather/de-interleave pattern.
139 const __m512 t10 = (__m512)_mm512_mask_i32logather_pd(
140 _mm512_undefined_pd(), r2mask, table_int, energyTable, _MM_SCALE_4);
141 const __m512 t11 = (__m512)_mm512_mask_i32logather_pd(
142 _mm512_undefined_pd(), r2mask2, table_int2, energyTable, _MM_SCALE_4);
143 const __m512i t4 = _mm512_set_epi32(31,29,27,25,23,21,19,17,
145 tabEnergyp1 = _mm512_permutex2var_ps(t10, t4, t11);
146 const __m512i t6 = _mm512_set_epi32(30,28,26,24,22,20,18,16,
148 tabEnergy = _mm512_permutex2var_ps(t10, t6, t11);
// Inverse powers: r_2 = 1/r^2, r_6 = 1/r^6, r_12 = 1/r^12.
153 const __m512 r_2 = _mm512_mul_ps(r_1, r_1);
155 const __m512 r_6 = _mm512_mul_ps(r_2, _mm512_mul_ps(r_2, r_2));
157 const __m512 r_12 = _mm512_mul_ps(r_6, r_6);
// Switching polynomial pieces: c2 = cutoff2 - r2, c4 = (c3 - 2*c2)*c2.
159 const __m512 c2 = _mm512_sub_ps(_mm512_set1_ps(cutoff2), r2);
161 const __m512 c4 = _mm512_mul_ps(_mm512_fnmadd_ps(_mm512_set1_ps(2.f), c2,
162 _mm512_set1_ps(c3)), c2);
// switchMask: lanes where switchOn2 < (second operand on a line not shown,
// presumably r2 - TODO confirm).
164 const __mmask16 switchMask = _mm512_cmplt_ps_mask(_mm512_set1_ps(switchOn2),
// switchVal = c1*c2*c4 in switchMask lanes, 1 elsewhere.
166 const __m512 switchVal = _mm512_mask_mov_ps(_mm512_set1_ps(1.f), switchMask,
167 _mm512_mul_ps(c2, _mm512_mul_ps(c4, _mm512_set1_ps(c1))));
// dSwitchVal = 2*c1*(c2*c2 - c4) in switchMask lanes, 0 elsewhere.
169 const __m512 dSwitchVal = _mm512_mask_mov_ps(_mm512_setzero_ps(),
170 switchMask, _mm512_mul_ps(_mm512_set1_ps(2.f),
171 _mm512_mul_ps(_mm512_set1_ps(c1), _mm512_fmsub_ps(c2,c2,c4))));
173 const __m512 r2SwitchVal = _mm512_mul_ps(switchVal, r_2);
// vdwAgradient = (dSwitchVal - 6*r2SwitchVal) * r_12
175 const __m512 vdwAgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
176 _mm512_set1_ps(6.f), r2SwitchVal, dSwitchVal), r_12);
// vdwBgradient = (dSwitchVal - 3*r2SwitchVal) * r_6
178 const __m512 vdwBgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
179 _mm512_set1_ps(3.f), r2SwitchVal, dSwitchVal), r_6);
// vdW force term: scaling * (A*vdwAgradient - B*vdwBgradient).
181 const __m512 vdwB = _mm512_mul_ps(_mm512_set1_ps(scaling),
182 _mm512_fmsub_ps(A,vdwAgradient, _mm512_mul_ps(B, vdwBgradient)));
// Fast electrostatic force. iMode==2 is analytic: kqq*(1/r^3 + mInvCut3).
// Otherwise it is kqq times the linear interpolation of fastTable.
186 if (iMode == 2) ffast = _mm512_mul_ps(kqq, _mm512_fmadd_ps(r_2,r_1,
187 _mm512_set1_ps(mInvCut3)));
189 else ffast = _mm512_mul_ps(kqq, _mm512_fmadd_ps(tabFast, rTableDiff,
190 _mm512_mul_ps(tabFastp1, tableDiff)));
// Analytic fast energy: efast = 0.5*(r2*mInvCut3 + cutUnder3) - r_1.
// The branch structure selecting it is on lines not shown.
196 efast = _mm512_fmadd_ps(r2,_mm512_set1_ps(mInvCut3),
197 _mm512_set1_ps(cutUnder3));
199 efast = _mm512_fmsub_ps(efast,_mm512_set1_ps(0.5f), r_1);
// Table-interpolated fast energy (alternative branch).
202 efast = _mm512_fmadd_ps(tabEnergy, rTableDiff,
203 _mm512_mul_ps(tabEnergyp1, tableDiff));
// vdW energy: energyVdw += switchVal*(A*r_12 - B*r_6), in r2mask lanes only.
206 const __m512 vdwTerm = _mm512_fmsub_ps(A, r_12, _mm512_mul_ps(B, r_6));
208 energyVdw = _mm512_mask_mov_ps(energyVdw, r2mask,
209 _mm512_fmadd_ps(switchVal, vdwTerm, energyVdw));
// Electrostatic energy: energyElec -= kqq*efast, in r2mask lanes only.
211 energyElec = _mm512_mask_mov_ps(energyElec, r2mask,
212 _mm512_fnmadd_ps(kqq, efast, energyElec));
// Total fast force (unmasked; masked-off lanes are not meaningful).
216 force = _mm512_sub_ps(vdwB, ffast);
// Slow (table-interpolated) force/energy; any doSlow guard is on a line not shown.
218 forceEnergySlow512<doEnergy>(r2mask, kqq, slowTable, slowEtable, table_int,
219 tableDiff, rTableDiff, forceSlow, energySlow);
// forceEnergyInterp1: per-lane van der Waals + electrostatic force/energy for
// 16 pairs. LJ coefficients are built on the fly from per-atom parameters
// instead of a type table. The fast electrostatic force and energy are always
// analytic here.
//   r2          - squared pair distances.
//   kqq         - per-lane electrostatic prefactor.
//   force       - output: vdwB - ffast (written in all lanes).
//   energyVdw / energyElec - in/out, updated only in r2mask lanes.
//   c1, c3, switchOn2 - switching-function parameters.
//   mInvCut3, cutUnder3 - constants of the analytic fast force/energy.
//   eps4i/eps4j, sigmaI/sigmaJ - per-lane atom parameters combined below.
// NOTE(review): several source lines are missing from this view. energySlow,
// table_int and cutoff2 are used but not declared in the visible text, and the
// guards for the energy/slow sections are not shown. TODO confirm against the
// full source.
// NOTE(review): the vdW force term is scaled by a fixed 2.f here, while
// forceEnergyInterp2 uses its `scaling` parameter. Confirm this is intended.
225 template<
bool doEnergy,
bool doSlow>
226 __forceinline
void forceEnergyInterp1(
const __m512 &r2,
const __m512 &kqq,
227 __m512 &force, __m512 &forceSlow,
228 __m512 &energyVdw, __m512 &energyElec,
230 const __mmask16 r2mask,
const float c1,
231 const float c3,
const float switchOn2,
233 const float mInvCut3,
234 const float cutUnder3,
235 const float * __restrict__ slowTable,
236 const float * __restrict__ slowEtable,
237 const __m512 &eps4i,
const __m512 &eps4j,
238 const __m512 &sigmaI,
239 const __m512 &sigmaJ) {
// Combination rule: geometric mean of the eps4 values and arithmetic mean of
// the sigmas.
242 const __m512 eps_ij = _mm512_sqrt_ps(_mm512_mul_ps(eps4i, eps4j));
244 __m512 sigma_ij = _mm512_mul_ps(_mm512_set1_ps(0.5f),
245 _mm512_add_ps(sigmaI, sigmaJ));
// sigma_ij^3, then (sigma_ij^3)^2 = sigma_ij^6.
247 sigma_ij = _mm512_mul_ps(sigma_ij, _mm512_mul_ps(sigma_ij, sigma_ij));
249 sigma_ij = _mm512_mul_ps(sigma_ij, sigma_ij);
// B = eps_ij*sigma^6, A = eps_ij*sigma^12.
251 const __m512
B(_mm512_mul_ps(sigma_ij, eps_ij));
253 const __m512
A(_mm512_mul_ps(B, sigma_ij));
// r_1 = 1/sqrt(r2). Table coordinates are computed unconditionally; in the
// visible code they are consumed only by the slow-term call at the end.
256 const __m512 r_1 = _mm512_invsqrt_ps(r2);
257 __m512 tableDiff, rTableDiff;;
260 getOmpSimdTableI(r_1, table_int, tableDiff, rTableDiff);
// Inverse powers: r_2 = 1/r^2, r_6 = 1/r^6, r_12 = 1/r^12.
263 const __m512 r_2 = _mm512_mul_ps(r_1, r_1);
265 const __m512 r_6 = _mm512_mul_ps(r_2, _mm512_mul_ps(r_2, r_2));
267 const __m512 r_12 = _mm512_mul_ps(r_6, r_6);
// Switching polynomial pieces: c2 = cutoff2 - r2, c4 = (c3 - 2*c2)*c2.
269 const __m512 c2 = _mm512_sub_ps(_mm512_set1_ps(cutoff2), r2);
271 const __m512 c4 = _mm512_mul_ps(_mm512_fnmadd_ps(_mm512_set1_ps(2.f), c2,
272 _mm512_set1_ps(c3)), c2);
// switchMask: lanes where switchOn2 < (second operand on a line not shown,
// presumably r2 - TODO confirm).
274 const __mmask16 switchMask = _mm512_cmplt_ps_mask(_mm512_set1_ps(switchOn2),
// switchVal = c1*c2*c4 in switchMask lanes, 1 elsewhere.
276 const __m512 switchVal = _mm512_mask_mov_ps(_mm512_set1_ps(1.f), switchMask,
277 _mm512_mul_ps(c2,_mm512_mul_ps(c4,_mm512_set1_ps(c1))));
// dSwitchVal = 2*c1*(c2*c2 - c4) in switchMask lanes, 0 elsewhere.
279 const __m512 dSwitchVal = _mm512_mask_mov_ps(_mm512_setzero_ps(),
280 switchMask, _mm512_mul_ps(_mm512_set1_ps(2.f),
281 _mm512_mul_ps(_mm512_set1_ps(c1), _mm512_fmsub_ps(c2,c2,c4))));
283 const __m512 r2SwitchVal = _mm512_mul_ps(switchVal, r_2);
// vdwAgradient = (dSwitchVal - 6*r2SwitchVal) * r_12
285 const __m512 vdwAgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
286 _mm512_set1_ps(6.f), r2SwitchVal, dSwitchVal), r_12);
// vdwBgradient = (dSwitchVal - 3*r2SwitchVal) * r_6
288 const __m512 vdwBgradient = _mm512_mul_ps(_mm512_fnmadd_ps(
289 _mm512_set1_ps(3.f), r2SwitchVal, dSwitchVal), r_6);
// vdW force term: 2 * (A*vdwAgradient - B*vdwBgradient).
291 const __m512 vdwB = _mm512_mul_ps(_mm512_set1_ps(2.f),
292 _mm512_fmsub_ps(A, vdwAgradient, _mm512_mul_ps(B, vdwBgradient)));
// Analytic fast electrostatic force: kqq*(1/r^3 + mInvCut3).
294 const __m512 ffast = _mm512_mul_ps(kqq, _mm512_fmadd_ps(r_2, r_1,
295 _mm512_set1_ps(mInvCut3)));
// Analytic fast energy: efast = 0.5*(r2*mInvCut3 + cutUnder3) - r_1.
299 __m512 efast = _mm512_fmadd_ps(r2,_mm512_set1_ps(mInvCut3),
300 _mm512_set1_ps(cutUnder3));
302 efast = _mm512_fmsub_ps(efast, _mm512_set1_ps(0.5f), r_1);
// vdW energy: energyVdw += switchVal*(A*r_12 - B*r_6), in r2mask lanes only.
304 const __m512 vdwTerm = _mm512_fmsub_ps(A, r_12, _mm512_mul_ps(B, r_6));
306 energyVdw = _mm512_mask_mov_ps(energyVdw, r2mask,
307 _mm512_fmadd_ps(switchVal, vdwTerm, energyVdw));
// Electrostatic energy: energyElec -= kqq*efast, in r2mask lanes only.
309 energyElec = _mm512_mask_mov_ps(energyElec, r2mask,
310 _mm512_fnmadd_ps(kqq, efast, energyElec));
// Total fast force (unmasked).
314 force = _mm512_sub_ps(vdwB, ffast);
// Slow (table-interpolated) force/energy; any doSlow guard is on a line not shown.
316 forceEnergySlow512<doEnergy>(r2mask, kqq, slowTable, slowEtable, table_int,
317 tableDiff, rTableDiff, forceSlow, energySlow);
320 #endif // NAMD_AVXTILES
321 #endif // AVXTILESKERNEL_H
__global__ void const int const TileList *__restrict__ TileExcl *__restrict__ const int *__restrict__ const int const float2 *__restrict__ cudaTextureObject_t const int *__restrict__ const float3 const float3 const float3 const float4 *__restrict__ const float cutoff2