00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00021
00022
00023 #if defined(VMDCPUDISPATCH) && defined(VMDUSEAVX512)
00024
00025 #include <immintrin.h>
00026
00027 #include <math.h>
00028 #include <stdio.h>
00029 #include "Orbital.h"
00030 #include "DrawMolecule.h"
00031 #include "utilities.h"
00032 #include "Inform.h"
00033 #include "WKFThreads.h"
00034 #include "WKFUtils.h"
00035 #include "ProfileHooks.h"
00036
// Multiplicative factor converting a length in Angstroms to Bohr radii
// (atomic units); grid and atom coordinates are scaled by this before the
// basis functions are evaluated.
#define ANGS_TO_BOHR 1.88972612478289694072f

// __align(X): compiler-specific spelling of "align this object to X bytes".
// GNU-compatible compilers use the attribute syntax; everything else
// (including the Intel compiler) is given the __declspec form.
#if defined(__GNUC__) && ! defined(__INTEL_COMPILER)
#define __align(X)  __attribute__((aligned(X) ))
#else
#define __align(X) __declspec(align(X) )
#endif

// -log2(e).  Multiplying an exponent by this constant lets a base-2
// exponential evaluate exp(-x):  exp(-x) == exp2(x * MLOG2EF).
// Parenthesized so the leading minus sign can never bind to neighbouring
// tokens at the point of expansion.
#define MLOG2EF    (-1.44269504088896f)
00046
#if 0
// Debugging aid (compiled out): dump the 16 single-precision lanes of an
// AVX-512 register to stdout on one line, lowest lane first.
static void print_mm512_ps(__m512 v) {
  __attribute__((aligned(64))) float lanes[16];
  int lane = 0;

  // spill the register to memory so the lanes can be read as scalars
  _mm512_storeu_ps(&lanes[0], v);

  printf("mm512: ");
  while (lane < 16) {
    printf("%g ", lanes[lane]);
    lane++;
  }
  printf("\n");
}
#endif
00059
00060
00061
00062
00063
00064
00065 int evaluate_grid_avx512er(int numatoms,
00066 const float *wave_f, const float *basis_array,
00067 const float *atompos,
00068 const int *atom_basis,
00069 const int *num_shells_per_atom,
00070 const int *num_prim_per_shell,
00071 const int *shell_types,
00072 const int *numvoxels,
00073 float voxelsize,
00074 const float *origin,
00075 int density,
00076 float * orbitalgrid) {
00077 if (!orbitalgrid)
00078 return -1;
00079
00080 int nx, ny, nz;
00081 __attribute__((aligned(64))) float sxdelta[16];
00082 for (nx=0; nx<16; nx++)
00083 sxdelta[nx] = ((float) nx) * voxelsize * ANGS_TO_BOHR;
00084
00085
00086
00087 int numgridxy = numvoxels[0]*numvoxels[1];
00088 for (nz=0; nz<numvoxels[2]; nz++) {
00089 float grid_x, grid_y, grid_z;
00090 grid_z = origin[2] + nz * voxelsize;
00091 for (ny=0; ny<numvoxels[1]; ny++) {
00092 grid_y = origin[1] + ny * voxelsize;
00093 int gaddrzy = ny*numvoxels[0] + nz*numgridxy;
00094 for (nx=0; nx<numvoxels[0]; nx+=16) {
00095 grid_x = origin[0] + nx * voxelsize;
00096
00097
00098
00099 int at;
00100 int prim, shell;
00101
00102
00103 __m512 value = _mm512_set1_ps(0.0f);
00104
00105
00106 int ifunc = 0;
00107 int shell_counter = 0;
00108
00109
00110 for (at=0; at<numatoms; at++) {
00111 int maxshell = num_shells_per_atom[at];
00112 int prim_counter = atom_basis[at];
00113
00114
00115 float sxdist = (grid_x - atompos[3*at ])*ANGS_TO_BOHR;
00116 float sydist = (grid_y - atompos[3*at+1])*ANGS_TO_BOHR;
00117 float szdist = (grid_z - atompos[3*at+2])*ANGS_TO_BOHR;
00118
00119 float sydist2 = sydist*sydist;
00120 float szdist2 = szdist*szdist;
00121 float yzdist2 = sydist2 + szdist2;
00122
00123 __m512 xdelta = _mm512_load_ps(&sxdelta[0]);
00124 __m512 xdist = _mm512_set1_ps(sxdist);
00125 xdist = _mm512_add_ps(xdist, xdelta);
00126 __m512 ydist = _mm512_set1_ps(sydist);
00127 __m512 zdist = _mm512_set1_ps(szdist);
00128 __m512 xdist2 = _mm512_mul_ps(xdist, xdist);
00129 __m512 ydist2 = _mm512_mul_ps(ydist, ydist);
00130 __m512 zdist2 = _mm512_mul_ps(zdist, zdist);
00131 __m512 dist2 = _mm512_set1_ps(yzdist2);
00132 dist2 = _mm512_add_ps(dist2, xdist2);
00133
00134
00135
00136
00137
00138
00139 for (shell=0; shell < maxshell; shell++) {
00140 __m512 contracted_gto = _mm512_set1_ps(0.0f);
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150 int maxprim = num_prim_per_shell[shell_counter];
00151 int shelltype = shell_types[shell_counter];
00152 for (prim=0; prim<maxprim; prim++) {
00153
00154 float exponent = -basis_array[prim_counter ];
00155 float contract_coeff = basis_array[prim_counter + 1];
00156
00157
00158 #if 1
00159 __m512 expval = _mm512_mul_ps(_mm512_set1_ps(-exponent * MLOG2EF), dist2);
00160
00161 __m512 retval = _mm512_exp2a23_ps(expval);
00162 contracted_gto = _mm512_fmadd_ps(_mm512_set1_ps(contract_coeff), retval, contracted_gto);
00163 #else
00164 __m512 expval = _mm512_mul_ps(_mm512_set1_ps(-exponent), dist2);
00165
00166 expval = _mm512_mul_ps(expval, _mm512_set1_ps(MLOG2EF));
00167 __m512 retval = _mm512_exp2a23_ps(expval);
00168 __m512 ctmp = _mm512_mul_ps(_mm512_set1_ps(contract_coeff), retval);
00169 contracted_gto = _mm512_add_ps(contracted_gto, ctmp);
00170 #endif
00171
00172 prim_counter += 2;
00173 }
00174
00175
00176 __m512 tmpshell = _mm512_set1_ps(0.0f);
00177 switch (shelltype) {
00178
00179 case S_SHELL:
00180 value = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), contracted_gto, value);
00181 break;
00182
00183 case P_SHELL:
00184 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), xdist, tmpshell);
00185 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), ydist, tmpshell);
00186 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), zdist, tmpshell);
00187 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00188 break;
00189
00190 case D_SHELL:
00191 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), xdist2, tmpshell);
00192 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist, ydist), tmpshell);
00193 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), ydist2, tmpshell);
00194 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist, zdist), tmpshell);
00195 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist, zdist), tmpshell);
00196 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), zdist2, tmpshell);
00197 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00198 break;
00199
00200 case F_SHELL:
00201 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, xdist), tmpshell);
00202 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, ydist), tmpshell);
00203 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, xdist), tmpshell);
00204 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, ydist), tmpshell);
00205 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(xdist2, zdist), tmpshell);
00206 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(_mm512_mul_ps(xdist, ydist), zdist), tmpshell);
00207 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(ydist2, zdist), tmpshell);
00208 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, xdist), tmpshell);
00209 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, ydist), tmpshell);
00210 tmpshell = _mm512_fmadd_ps(_mm512_set1_ps(wave_f[ifunc++]), _mm512_mul_ps(zdist2, zdist), tmpshell);
00211 value = _mm512_fmadd_ps(tmpshell, contracted_gto, value);
00212 break;
00213
00214
00215 #if 0
00216 default:
00217
00218 int i, j;
00219 float xdp, ydp, zdp;
00220 float xdiv = 1.0f / xdist;
00221 for (j=0, zdp=1.0f; j<=shelltype; j++, zdp*=zdist) {
00222 int imax = shelltype - j;
00223 for (i=0, ydp=1.0f, xdp=pow(xdist, imax); i<=imax; i++, ydp*=ydist, xdp*=xdiv) {
00224 tmpshell += wave_f[ifunc++] * xdp * ydp * zdp;
00225 }
00226 }
00227 value += tmpshell * contracted_gto;
00228 #endif
00229 }
00230
00231 shell_counter++;
00232 }
00233 }
00234
00235
00236 if (density) {
00237 __mmask16 mask = _mm512_cmplt_ps_mask(value, _mm512_set1_ps(0.0f));
00238 __m512 sqdensity = _mm512_mul_ps(value, value);
00239 __m512 orbdensity = _mm512_mask_mul_ps(sqdensity, mask, sqdensity,
00240 _mm512_set1_ps(-1.0f));
00241 _mm512_storeu_ps(&orbitalgrid[gaddrzy + nx], orbdensity);
00242 } else {
00243 _mm512_storeu_ps(&orbitalgrid[gaddrzy + nx], value);
00244 }
00245 }
00246 }
00247 }
00248
00249
00250
00251
00252
00253
00254
00255
00256
00257 _mm256_zeroupper();
00258
00259 return 0;
00260 }
00261
00262 #endif
00263
00264