Skip to content

Commit

Permalink
Renamed the existing x-direction cfftx routines to rfftx, then added new cfftx routines that perform a complex-to-complex (Z2Z) transform ... EJB
Browse files Browse the repository at this point in the history
  • Loading branch information
ebylaska committed Dec 7, 2023
1 parent 3db0caf commit 83a0181
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 96 deletions.
8 changes: 4 additions & 4 deletions Nwpw/nwpwlib/D3dB/d3db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3182,7 +3182,7 @@ void d3db::cr_fft3d(double *a)
indx0 += nxhy2;
}

mygdevice.batch_cfftx_tmpx(fft_tag,false, nx, ny * nq, n2ft3d, a, tmpx);
mygdevice.batch_rfftx_tmpx(fft_tag,false, nx, ny * nq, n2ft3d, a, tmpx);
}

/*************************
Expand Down Expand Up @@ -3211,7 +3211,7 @@ void d3db::cr_fft3d(double *a)
*** do fft along kx dimension ***
*** A(nx,ny,nz) <- fft1d^(-1)[A(kx,ny,nz)] ***
************************************************/
mygdevice.batch_cfftx_tmpx(fft_tag,false, nx, nq1, n2ft3d, a, tmpx);
mygdevice.batch_rfftx_tmpx(fft_tag,false, nx, nq1, n2ft3d, a, tmpx);

zeroend_fftb(nx, nq1, 1, 1, a);
if (n2ft3d_map < n2ft3d)
Expand Down Expand Up @@ -3252,7 +3252,7 @@ void d3db::rc_fft3d(double *a)
*** do fft along nx dimension ***
*** A(kx,ny,nz) <- fft1d[A(nx,ny,nz)] ***
********************************************/
mygdevice.batch_cfftx_tmpx(fft_tag,true, nx, ny*nq, n2ft3d, a, tmpx);
mygdevice.batch_rfftx_tmpx(fft_tag,true, nx, ny*nq, n2ft3d, a, tmpx);

/********************************************
*** do fft along ny dimension ***
Expand Down Expand Up @@ -3368,7 +3368,7 @@ void d3db::rc_fft3d(double *a)
*** do fft along nx dimension ***
*** A(kx,ny,nz) <- fft1d[A(nx,ny,nz)] ***
********************************************/
mygdevice.batch_cfftx_tmpx(fft_tag,true, nx, nq1, n2ft3d, a, tmpx);
mygdevice.batch_rfftx_tmpx(fft_tag,true, nx, nq1, n2ft3d, a, tmpx);

c_transpose_ijk(0, a, tmp2, tmp3);

Expand Down
21 changes: 21 additions & 0 deletions Nwpw/nwpwlib/device/gdevice2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,27 @@ void gdevice2::batch_fft_pipeline_mem_init(const int nstages, const int n2ft3d)
}



/**
 * Batched x-direction "rfftx" transform; forwards the request to the
 * Gdevices backend referenced by mygdevice2.
 *
 * GPU builds (NWPW_CUDA or NWPW_HIP defined): calls Gdevices::batch_rfftx
 * with the plan tag, but only when hasgpu is set. When hasgpu is false the
 * call does nothing. tmpx is not used in this configuration.
 *
 * All other builds: calls Gdevices::batch_rfftx_tmpx, passing tmpx through.
 * tag is not used in this configuration.
 */
void gdevice2::batch_rfftx_tmpx(const int tag, bool forward, int nx, int nq,
                                int n2ft3d, double *a, double *tmpx)
{
#if defined(NWPW_CUDA) || defined(NWPW_HIP)
  // Nothing to dispatch to when the backend reports no GPU.
  if (!mygdevice2->hasgpu)
    return;

  mygdevice2->batch_rfftx(tag, forward, nx, nq, n2ft3d, a);
#else
  mygdevice2->batch_rfftx_tmpx(forward, nx, nq, n2ft3d, a, tmpx);
#endif
}

/**
 * Staged variant of the batched x-direction "rfftx" transform.
 *
 * GPU builds (NWPW_CUDA or NWPW_HIP defined): forwards stage, tag, the
 * transform arguments and the slot index da to Gdevices::batch_rfftx_stages,
 * but only when hasgpu is set.
 *
 * All other builds: the body is empty, so the call is a no-op.
 * tmpx is never used by this wrapper.
 */
void gdevice2::batch_rfftx_stages_tmpx(const int stage, const int tag, bool forward,
                                       int nx, int nq, int n2ft3d,
                                       double *a, double *tmpx, int da)
{
#if defined(NWPW_CUDA) || defined(NWPW_HIP)
  // Nothing to dispatch to when the backend reports no GPU.
  if (!mygdevice2->hasgpu)
    return;

  mygdevice2->batch_rfftx_stages(stage, tag, forward, nx, nq, n2ft3d, a, da);
#endif
}



void gdevice2::batch_cfftx_tmpx(const int tag,bool forward, int nx, int nq, int n2ft3d,
double *a, double *tmpx) {
#if defined(NWPW_CUDA) || defined(NWPW_HIP)
Expand Down
2 changes: 2 additions & 0 deletions Nwpw/nwpwlib/device/gdevice2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ Gdevices *mygdevice2;

void batch_fft_pipeline_mem_init(const int,const int);

void batch_rfftx_tmpx(const int, bool, int, int, int, double *, double *);
void batch_cfftx_tmpx(const int, bool, int, int, int, double *, double *);
void batch_cffty_tmpy(const int, bool, int, int, int, double *, double *);
void batch_cfftz_tmpz(const int, bool, int, int, int, double *, double *);

void batch_rfftx_stages_tmpx(const int,const int, bool, int, int, int, double *, double *,int);
void batch_cfftx_stages_tmpx(const int,const int, bool, int, int, int, double *, double *,int);
void batch_cffty_stages_tmpy(const int,const int, bool, int, int, int, double *, double *,int);
void batch_cfftz_stages_tmpz(const int,const int, bool, int, int, int, double *, double *,int);
Expand Down
29 changes: 26 additions & 3 deletions Nwpw/nwpwlib/device/gdevices.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class Gdevices {
}
}

void batch_cfftx_tmpx(bool forward, int nx, int nq, int n2ft3d, double *a, double *tmpx)
void batch_rfftx_tmpx(bool forward, int nx, int nq, int n2ft3d, double *a, double *tmpx)
{
int nxh2 = nx + 2;
if (forward)
Expand Down Expand Up @@ -258,6 +258,29 @@ class Gdevices {
}
}


/**
 * Host-side batched FFT along x using the dcfftf_/dcfftb_ routines.
 *
 * Makes nq calls, one per batch entry, each on the next segment of `a`;
 * consecutive segments are 2*nx doubles apart. `forward` selects dcfftf_,
 * otherwise dcfftb_ is used. tmpx is handed to every call unchanged.
 * n2ft3d is not used by this routine.
 */
void batch_cfftx_tmpx(bool forward, int nx, int nq, int n2ft3d, double *a, double *tmpx)
{
  double *segment = a;

  if (!forward)
  {
    // Backward transform of each segment in turn.
    for (int batch = 0; batch < nq; ++batch, segment += (2*nx))
      dcfftb_(&nx, segment, tmpx);
    return;
  }

  // Forward transform of each segment in turn.
  for (int batch = 0; batch < nq; ++batch, segment += (2*nx))
    dcfftf_(&nx, segment, tmpx);
}

void batch_cffty_tmpy(bool forward, int ny, int nq, int n2ft3d, double *a, double *tmpy)
{
if (forward)
Expand All @@ -266,7 +289,7 @@ class Gdevices {
for (auto q = 0; q < nq; ++q)
{
dcfftf_(&ny, a + indx, tmpy);
indx += (2 * ny);
indx += (2*ny);
}
}
else
Expand All @@ -275,7 +298,7 @@ class Gdevices {
for (auto q = 0; q < nq; ++q)
{
dcfftb_(&ny, a + indx, tmpy);
indx += (2 * ny);
indx += (2*ny);
}
}
}
Expand Down
89 changes: 83 additions & 6 deletions Nwpw/nwpwlib/device/gdevices_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class Gdevices {

int fftcount = 0;
int nxfft[2],nyfft[2], nzfft[2];
cufftHandle forward_plan_x[2] = {0,0}, plan_y[2] = {0,0}, plan_z[2] = {0,0};
cufftHandle forward_plan_x[2] = {0,0}, plan_x[2]={0,0}, plan_y[2] = {0,0}, plan_z[2] = {0,0};
cufftHandle backward_plan_x[2] = {0,0};
int ifft_dev[15];
int ifft_n;
Expand Down Expand Up @@ -1013,7 +1013,10 @@ class Gdevices {
if (DEBUG_IO) std::cout << "Into batch_fft_init" << std::endl;
NWPW_CUFFT_ERROR(cufftPlan1d(&forward_plan_x[fftcount], nx, CUFFT_D2Z, nq1));
NWPW_CUFFT_ERROR(cufftPlan1d(&backward_plan_x[fftcount], nx, CUFFT_Z2D, nq1));


int x_inembed[] = {nx};
int x_onembed[] = {nx};
NWPW_CUFFT_ERROR(cufftPlanMany(&plan_x[fftcount], 1, &nx, x_inembed, 1, nx, x_onembed, 1, nx, CUFFT_Z2Z, nq1));
int y_inembed[] = {ny};
int y_onembed[] = {ny};
NWPW_CUFFT_ERROR(cufftPlanMany(&plan_y[fftcount], 1, &ny, y_inembed, 1, ny, y_onembed, 1, ny, CUFFT_Z2Z, nq2));
Expand Down Expand Up @@ -1054,6 +1057,7 @@ class Gdevices {
{
// free fft descriptors
NWPW_CUFFT_ERROR(cufftDestroy(forward_plan_x[tag]));
NWPW_CUFFT_ERROR(cufftDestroy(plan_x[tag]));
NWPW_CUFFT_ERROR(cufftDestroy(plan_y[tag]));
NWPW_CUFFT_ERROR(cufftDestroy(plan_z[tag]));
NWPW_CUFFT_ERROR(cufftDestroy(backward_plan_x[tag]));
Expand All @@ -1072,10 +1076,10 @@ class Gdevices {

/**************************************
* *
* batch_cfftx *
* batch_rfftx *
* *
**************************************/
void batch_cfftx(const int fft_indx, bool forward, int nx, int nq, int n2ft3d, double *a)
void batch_rfftx(const int fft_indx, bool forward, int nx, int nq, int n2ft3d, double *a)
{
int ia_dev = fetch_dev_mem_indx(((size_t)n2ft3d));
NWPW_CUDA_ERROR(cudaMemcpy(dev_mem[ia_dev], a, n2ft3d * sizeof(double), cudaMemcpyHostToDevice));
Expand All @@ -1098,10 +1102,10 @@ class Gdevices {

/**************************************
* *
* batch_cfftx_stages *
* batch_rfftx_stages *
* *
**************************************/
void batch_cfftx_stages(const int stage, const int fft_indx, bool forward, int nx, int nq, int n2ft3d, double *a, int da)
void batch_rfftx_stages(const int stage, const int fft_indx, bool forward, int nx, int nq, int n2ft3d, double *a, int da)
{
//int ia_dev = fetch_dev_mem_indx(((size_t) n2ft3d));
int ia_dev = ifft_dev[da];
Expand Down Expand Up @@ -1134,6 +1138,79 @@ class Gdevices {
inuse[ia_dev] = false;
}
}



/**************************************
* *
* batch_cfftx *
* *
**************************************/
void batch_cfftx(const int fft_indx, bool forward, int nx, int nq, int n2ft3d, double *a)
{
int ia_dev = fetch_dev_mem_indx(((size_t)n2ft3d));
NWPW_CUDA_ERROR(cudaMemcpy(dev_mem[ia_dev], a, n2ft3d * sizeof(double), cudaMemcpyHostToDevice));

if (forward) {
NWPW_CUFFT_ERROR(cufftExecZ2Z(
plan_x[fft_indx], reinterpret_cast<cufftDoubleComplex *>(dev_mem[ia_dev]),
reinterpret_cast<cufftDoubleComplex *>(dev_mem[ia_dev]),
CUFFT_FORWARD));
} else {
NWPW_CUFFT_ERROR(cufftExecZ2Z(
plan_x[fft_indx], reinterpret_cast<cufftDoubleComplex *>(dev_mem[ia_dev]),
reinterpret_cast<cufftDoubleComplex *>(dev_mem[ia_dev]),
CUFFT_INVERSE));
}

NWPW_CUDA_ERROR(cudaMemcpy(a, dev_mem[ia_dev], n2ft3d * sizeof(double), cudaMemcpyDeviceToHost));

inuse[ia_dev] = false;
}



/**************************************
* *
* batch_cfftx_stages *
* *
**************************************/
void batch_cfftx_stages(const int stage, const int fft_indx, bool forward, int nx, int nq, int n2ft3d, double *a, int da)
{
//int ia_dev = fetch_dev_mem_indx(((size_t)n2ft3d));
int ia_dev = ifft_dev[da];
if (stage==0)
{
inuse[ia_dev] = true;
NWPW_CUDA_ERROR(cudaMemcpyAsync(dev_mem[ia_dev],a,n2ft3d*sizeof(double),cudaMemcpyHostToDevice,stream[da]));
}
else if (stage==1)
{
//NWPW_CUDA_ERROR(cudaStreamSynchronize(stream[da]));
if (forward) {
NWPW_CUFFT_ERROR(cufftExecZ2Z(plan_x[fft_indx],
reinterpret_cast<cufftDoubleComplex *>(dev_mem[ia_dev]),
reinterpret_cast<cufftDoubleComplex *>(dev_mem[ia_dev]),
CUFFT_FORWARD));
} else {
NWPW_CUFFT_ERROR(cufftExecZ2Z(plan_x[fft_indx],
reinterpret_cast<cufftDoubleComplex *>(dev_mem[ia_dev]),
reinterpret_cast<cufftDoubleComplex *>(dev_mem[ia_dev]),
CUFFT_INVERSE));
}
NWPW_CUDA_ERROR(cudaMemcpyAsync(a,dev_mem[ia_dev],n2ft3d*sizeof(double),cudaMemcpyDeviceToHost,stream[da]));
}
else if (stage==2)
{
NWPW_CUDA_ERROR(cudaStreamSynchronize(stream[da]));
inuse[ia_dev] = false;
}
}







/**************************************
Expand Down
Loading

0 comments on commit 83a0181

Please sign in to comment.