NAMD
Classes | Public Member Functions | List of all members
CudaPmeKSpaceCompute Class Reference

#include <CudaPmeSolverUtil.h>

Inheritance diagram for CudaPmeKSpaceCompute:
PmeKSpaceCompute

Public Member Functions

 CudaPmeKSpaceCompute (PmeGrid pmeGrid, const int permutation, const int jblock, const int kblock, double kappa, int deviceID, cudaStream_t stream)
 
 ~CudaPmeKSpaceCompute ()
 
void solve (Lattice &lattice, const bool doEnergy, const bool doVirial, float *data)
 
double getEnergy ()
 
void getVirial (double *virial)
 
void energyAndVirialSetCallback (CudaPmePencilXYZ *pencilPtr)
 
void energyAndVirialSetCallback (CudaPmePencilZ *pencilPtr)
 
- Public Member Functions inherited from PmeKSpaceCompute
 PmeKSpaceCompute (PmeGrid pmeGrid, const int permutation, const int jblock, const int kblock, double kappa)
 
virtual ~PmeKSpaceCompute ()
 

Additional Inherited Members

- Protected Attributes inherited from PmeKSpaceCompute
PmeGrid pmeGrid
 
double * bm1
 
double * bm2
 
double * bm3
 
double kappa
 
const int permutation
 
const int jblock
 
const int kblock
 
int size1
 
int size2
 
int size3
 
int j0
 
int k0
 

Detailed Description

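CUDA implementation of PmeKSpaceCompute.
- The constructor uploads the bm1/bm2/bm3 prefactor tables to the GPU as float arrays.
- solve() runs scalar_sum() on the k-space grid passed in `data`, on the object's CUDA stream. When energy or virial is requested, it also queues a copy of the EnergyVirial record to the host and records copyEnergyVirialEvent.
- getEnergy() and getVirial() read that host copy.
- energyAndVirialSetCallback() stores the target pencil and schedules energyAndVirialCheck with CcdCallFnAfter.

Definition at line 59 of file CudaPmeSolverUtil.h.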

Constructor & Destructor Documentation

CudaPmeKSpaceCompute::CudaPmeKSpaceCompute ( PmeGrid  pmeGrid,
const int  permutation,
const int  jblock,
const int  kblock,
double  kappa,
int  deviceID,
cudaStream_t  stream 
)

Definition at line 191 of file CudaPmeSolverUtil.C.

References PmeKSpaceCompute::bm1, PmeKSpaceCompute::bm2, PmeKSpaceCompute::bm3, cudaCheck, PmeGrid::K1, PmeGrid::K2, and PmeGrid::K3.

192  :
194  deviceID(deviceID), stream(stream) {
195 
196  cudaCheck(cudaSetDevice(deviceID));
197 
198  // Copy bm1 -> prefac_x on GPU memory
199  float *bm1f = new float[pmeGrid.K1];
200  float *bm2f = new float[pmeGrid.K2];
201  float *bm3f = new float[pmeGrid.K3];
202  for (int i=0;i < pmeGrid.K1;i++) bm1f[i] = (float)bm1[i];
203  for (int i=0;i < pmeGrid.K2;i++) bm2f[i] = (float)bm2[i];
204  for (int i=0;i < pmeGrid.K3;i++) bm3f[i] = (float)bm3[i];
205  allocate_device<float>(&d_bm1, pmeGrid.K1);
206  allocate_device<float>(&d_bm2, pmeGrid.K2);
207  allocate_device<float>(&d_bm3, pmeGrid.K3);
208  copy_HtoD_sync<float>(bm1f, d_bm1, pmeGrid.K1);
209  copy_HtoD_sync<float>(bm2f, d_bm2, pmeGrid.K2);
210  copy_HtoD_sync<float>(bm3f, d_bm3, pmeGrid.K3);
211  delete [] bm1f;
212  delete [] bm2f;
213  delete [] bm3f;
214  allocate_device<EnergyVirial>(&d_energyVirial, 1);
215  allocate_host<EnergyVirial>(&h_energyVirial, 1);
216  // cudaCheck(cudaEventCreateWithFlags(&copyEnergyVirialEvent, cudaEventDisableTiming));
217  cudaCheck(cudaEventCreate(&copyEnergyVirialEvent));
218  // ncall = 0;
219 }
int K2
Definition: PmeBase.h:18
int K1
Definition: PmeBase.h:18
__thread cudaStream_t stream
PmeKSpaceCompute(PmeGrid pmeGrid, const int permutation, const int jblock, const int kblock, double kappa)
int K3
Definition: PmeBase.h:18
const int permutation
#define cudaCheck(stmt)
Definition: CudaUtils.h:79
CudaPmeKSpaceCompute::~CudaPmeKSpaceCompute ( )

Definition at line 221 of file CudaPmeSolverUtil.C.

References cudaCheck.

221  {
222  cudaCheck(cudaSetDevice(deviceID));
223  deallocate_device<float>(&d_bm1);
224  deallocate_device<float>(&d_bm2);
225  deallocate_device<float>(&d_bm3);
226  deallocate_device<EnergyVirial>(&d_energyVirial);
227  deallocate_host<EnergyVirial>(&h_energyVirial);
228  cudaCheck(cudaEventDestroy(copyEnergyVirialEvent));
229 }
#define cudaCheck(stmt)
Definition: CudaUtils.h:79

Member Function Documentation

void CudaPmeKSpaceCompute::energyAndVirialSetCallback ( CudaPmePencilXYZ pencilPtr)

Definition at line 381 of file CudaPmeSolverUtil.C.

References CcdCallBacksReset(), and cudaCheck.

381  {
382  cudaCheck(cudaSetDevice(deviceID));
383  pencilXYZPtr = pencilPtr;
384  pencilZPtr = NULL;
385  checkCount = 0;
386  CcdCallBacksReset(0, CmiWallTimer());
387  // Set the call back at 0.1ms
388  CcdCallFnAfter(energyAndVirialCheck, this, 0.1);
389 }
void CcdCallBacksReset(void *ignored, double curWallTime)
#define cudaCheck(stmt)
Definition: CudaUtils.h:79
void CudaPmeKSpaceCompute::energyAndVirialSetCallback ( CudaPmePencilZ pencilPtr)

Definition at line 391 of file CudaPmeSolverUtil.C.

References CcdCallBacksReset(), and cudaCheck.

391  {
392  cudaCheck(cudaSetDevice(deviceID));
393  pencilXYZPtr = NULL;
394  pencilZPtr = pencilPtr;
395  checkCount = 0;
396  CcdCallBacksReset(0, CmiWallTimer());
397  // Set the call back at 0.1ms
398  CcdCallFnAfter(energyAndVirialCheck, this, 0.1);
399 }
void CcdCallBacksReset(void *ignored, double curWallTime)
#define cudaCheck(stmt)
Definition: CudaUtils.h:79
double CudaPmeKSpaceCompute::getEnergy ( )
virtual

Implements PmeKSpaceCompute.

Definition at line 401 of file CudaPmeSolverUtil.C.

401  {
402  return h_energyVirial->energy;
403 }
void CudaPmeKSpaceCompute::getVirial ( double *  virial)
virtual

Implements PmeKSpaceCompute.

Definition at line 405 of file CudaPmeSolverUtil.C.

References Perm_cX_Y_Z, Perm_Z_cX_Y, and PmeKSpaceCompute::permutation.

405  {
406  if (permutation == Perm_Z_cX_Y) {
407  // h_energyVirial->virial is storing ZZ, ZX, ZY, XX, XY, YY
408  virial[0] = h_energyVirial->virial[3];
409  virial[1] = h_energyVirial->virial[4];
410  virial[2] = h_energyVirial->virial[1];
411 
412  virial[3] = h_energyVirial->virial[4];
413  virial[4] = h_energyVirial->virial[5];
414  virial[5] = h_energyVirial->virial[2];
415 
416  virial[6] = h_energyVirial->virial[1];
417  virial[7] = h_energyVirial->virial[7];
418  virial[8] = h_energyVirial->virial[0];
419  } else if (permutation == Perm_cX_Y_Z) {
420  // h_energyVirial->virial is storing XX, XY, XZ, YY, YZ, ZZ
421  virial[0] = h_energyVirial->virial[0];
422  virial[1] = h_energyVirial->virial[1];
423  virial[2] = h_energyVirial->virial[2];
424 
425  virial[3] = h_energyVirial->virial[1];
426  virial[4] = h_energyVirial->virial[3];
427  virial[5] = h_energyVirial->virial[4];
428 
429  virial[6] = h_energyVirial->virial[2];
430  virial[7] = h_energyVirial->virial[4];
431  virial[8] = h_energyVirial->virial[5];
432  }
433 }
const int permutation
void CudaPmeKSpaceCompute::solve ( Lattice lattice,
const bool  doEnergy,
const bool  doVirial,
float *  data 
)
virtual

Implements PmeKSpaceCompute.

Definition at line 231 of file CudaPmeSolverUtil.C.

References Lattice::a(), Lattice::a_r(), Lattice::b(), Lattice::b_r(), Lattice::c(), Lattice::c_r(), cudaCheck, PmeKSpaceCompute::j0, PmeKSpaceCompute::k0, PmeGrid::K1, PmeGrid::K2, PmeGrid::K3, PmeKSpaceCompute::kappa, NAMD_bug(), Perm_cX_Y_Z, Perm_Z_cX_Y, PmeKSpaceCompute::permutation, PmeKSpaceCompute::pmeGrid, scalar_sum(), PmeKSpaceCompute::size1, PmeKSpaceCompute::size2, PmeKSpaceCompute::size3, Lattice::volume(), Vector::x, Vector::y, and Vector::z.

231  {
232 #if 0
233  // Check lattice to make sure it is updating for constant pressure
234  fprintf(stderr, "K-SPACE LATTICE %g %g %g %g %g %g %g %g %g\n",
235  lattice.a().x, lattice.a().y, lattice.a().z,
236  lattice.b().x, lattice.b().y, lattice.b().z,
237  lattice.c().x, lattice.c().y, lattice.c().z);
238 #endif
239  cudaCheck(cudaSetDevice(deviceID));
240 
241  const bool doEnergyVirial = (doEnergy || doVirial);
242 
243  int nfft1, nfft2, nfft3;
244  float *prefac1, *prefac2, *prefac3;
245 
246  BigReal volume = lattice.volume();
247  Vector a_r = lattice.a_r();
248  Vector b_r = lattice.b_r();
249  Vector c_r = lattice.c_r();
250  float recip1x, recip1y, recip1z;
251  float recip2x, recip2y, recip2z;
252  float recip3x, recip3y, recip3z;
253 
254  if (permutation == Perm_Z_cX_Y) {
255  // Z, X, Y
256  nfft1 = pmeGrid.K3;
257  nfft2 = pmeGrid.K1;
258  nfft3 = pmeGrid.K2;
259  prefac1 = d_bm3;
260  prefac2 = d_bm1;
261  prefac3 = d_bm2;
262  recip1x = c_r.z;
263  recip1y = c_r.x;
264  recip1z = c_r.y;
265  recip2x = a_r.z;
266  recip2y = a_r.x;
267  recip2z = a_r.y;
268  recip3x = b_r.z;
269  recip3y = b_r.x;
270  recip3z = b_r.y;
271  } else if (permutation == Perm_cX_Y_Z) {
272  // X, Y, Z
273  nfft1 = pmeGrid.K1;
274  nfft2 = pmeGrid.K2;
275  nfft3 = pmeGrid.K3;
276  prefac1 = d_bm1;
277  prefac2 = d_bm2;
278  prefac3 = d_bm3;
279  recip1x = a_r.x;
280  recip1y = a_r.y;
281  recip1z = a_r.z;
282  recip2x = b_r.x;
283  recip2y = b_r.y;
284  recip2z = b_r.z;
285  recip3x = c_r.x;
286  recip3y = c_r.y;
287  recip3z = c_r.z;
288  } else {
289  NAMD_bug("CudaPmeKSpaceCompute::solve, invalid permutation");
290  }
291 
292  // ncall++;
293  // if (ncall == 1) {
294  // char filename[256];
295  // sprintf(filename,"dataf_%d_%d.txt",jblock,kblock);
296  // writeComplexToDisk((float2*)data, size1*size2*size3, filename, stream);
297  // }
298 
299  // if (ncall == 1) {
300  // float2* h_data = new float2[size1*size2*size3];
301  // float2* d_data = (float2*)data;
302  // copy_DtoH<float2>(d_data, h_data, size1*size2*size3, stream);
303  // cudaCheck(cudaStreamSynchronize(stream));
304  // FILE *handle = fopen("dataf.txt", "w");
305  // for (int z=0;z < pmeGrid.K3;z++) {
306  // for (int y=0;y < pmeGrid.K2;y++) {
307  // for (int x=0;x < pmeGrid.K1/2+1;x++) {
308  // int i;
309  // if (permutation == Perm_cX_Y_Z) {
310  // i = x + y*size1 + z*size1*size2;
311  // } else {
312  // i = z + x*size1 + y*size1*size2;
313  // }
314  // fprintf(handle, "%f %f\n", h_data[i].x, h_data[i].y);
315  // }
316  // }
317  // }
318  // fclose(handle);
319  // delete [] h_data;
320  // }
321 
322  // Clear energy and virial array if needed
323  if (doEnergyVirial) clear_device_array<EnergyVirial>(d_energyVirial, 1, stream);
324 
325  scalar_sum(permutation == Perm_cX_Y_Z, nfft1, nfft2, nfft3, size1, size2, size3, kappa,
326  recip1x, recip1y, recip1z, recip2x, recip2y, recip2z, recip3x, recip3y, recip3z,
327  volume, prefac1, prefac2, prefac3, j0, k0, doEnergyVirial,
328  &d_energyVirial->energy, d_energyVirial->virial, (float2*)data,
329  stream);
330 
331  // Copy energy and virial to host if needed
332  if (doEnergyVirial) {
333  copy_DtoH<EnergyVirial>(d_energyVirial, h_energyVirial, 1, stream);
334  cudaCheck(cudaEventRecord(copyEnergyVirialEvent, stream));
335  // cudaCheck(cudaStreamSynchronize(stream));
336  }
337 
338 }
Vector a_r() const
Definition: Lattice.h:268
void scalar_sum(const bool orderXYZ, const int nfft1, const int nfft2, const int nfft3, const int size1, const int size2, const int size3, const double kappa, const float recip1x, const float recip1y, const float recip1z, const float recip2x, const float recip2y, const float recip2z, const float recip3x, const float recip3y, const float recip3z, const double volume, const float *prefac1, const float *prefac2, const float *prefac3, const int k2_00, const int k3_00, const bool doEnergyVirial, double *energy, double *virial, float2 *data, cudaStream_t stream)
Definition: Vector.h:64
int K2
Definition: PmeBase.h:18
int K1
Definition: PmeBase.h:18
Vector c_r() const
Definition: Lattice.h:270
BigReal z
Definition: Vector.h:66
Vector b_r() const
Definition: Lattice.h:269
__thread cudaStream_t stream
void NAMD_bug(const char *err_msg)
Definition: common.C:123
BigReal x
Definition: Vector.h:66
BigReal volume(void) const
Definition: Lattice.h:277
int K3
Definition: PmeBase.h:18
const int permutation
BigReal y
Definition: Vector.h:66
Vector b() const
Definition: Lattice.h:253
#define cudaCheck(stmt)
Definition: CudaUtils.h:79
Vector a() const
Definition: Lattice.h:252
Vector c() const
Definition: Lattice.h:254
double BigReal
Definition: common.h:112

The documentation for this class was generated from the following files: