Actual source code: cufft.cu

petsc-3.9.4 2018-09-11
Report Typos and Errors

  2: /*
  3:     Provides an interface to the CUFFT package.
  4:     Testing examples can be found in ~src/mat/examples/tests
  5: */

  7:  #include <petsc/private/matimpl.h>
  8: EXTERN_C_BEGIN
  9: #include <cuda.h>
 10: #include <cuda_runtime.h>
 11: #include <cufft.h>
 12: EXTERN_C_END

 /* Private data attached to a SEQCUFFT matrix through Mat->data */
typedef struct {
  PetscInt     ndim;       /* number of dimensions of the transform */
  PetscInt     *dim;       /* ndim+1 entries: the extent of each dimension followed by their product */
  cufftHandle  p_forward;  /* plan used by MatMult(); 0 until it is created on first use */
  cufftHandle  p_backward; /* plan used by MatMultTranspose(); 0 until it is created on first use */
  cufftComplex *devArray;  /* buffer obtained from cudaMalloc() holding one cufftComplex per entry */
} Mat_CUFFT;

 /*
   MatMult_SeqCUFFT - Computes y = FFT(x) with a forward complex-to-complex CUFFT transform.

   Input Parameters:
+  A - matrix whose data pointer is a Mat_CUFFT (dimensions, cached plans, device buffer)
-  x - input vector; it is only read

   Output Parameter:
.  y - receives the transformed data

   Notes:
   The forward plan is built on the first call and cached in the Mat_CUFFT; only 1, 2 and
   3 dimensional transforms are supported. The input is copied into the buffer owned by the
   matrix, transformed in place there, and copied back out into y.

   NOTE(review): the copies are sized with sizeof(cufftComplex), which presumably requires
   PetscScalar to be a single precision complex type -- confirm against the build configuration.
*/
PetscErrorCode MatMult_SeqCUFFT(Mat A, Vec x, Vec y)
{
  Mat_CUFFT         *cufft    = (Mat_CUFFT*) A->data;
  cufftComplex      *devArray = cufft->devArray;
  PetscInt          ndim      = cufft->ndim, *dim = cufft->dim;
  size_t            nbytes    = sizeof(cufftComplex)*dim[ndim]; /* dim[ndim] stores the total number of entries */
  const PetscScalar *x_array;
  PetscScalar       *y_array;
  cufftResult       result;
  cudaError_t       cerr;
  PetscErrorCode    ierr;

  PetscFunctionBegin;
  /* Build the plan before touching the vectors so that a failure cannot leave an array checked out */
  if (!cufft->p_forward) {
    switch (ndim) {
    case 1:
      result = cufftPlan1d(&cufft->p_forward, dim[0], CUFFT_C2C, 1);
      break;
    case 2:
      result = cufftPlan2d(&cufft->p_forward, dim[0], dim[1], CUFFT_C2C);
      break;
    case 3:
      result = cufftPlan3d(&cufft->p_forward, dim[0], dim[1], dim[2], CUFFT_C2C);
      break;
    default:
      SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_USER, "Cannot create plan for %D-dimensional transform", ndim);
    }
    if (result != CUFFT_SUCCESS) {
      /* keep the handle marked as "not created" so that a later call tries again */
      cufft->p_forward = 0;
      SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "CUFFT forward plan creation failed with error code %d", (int)result);
    }
  }

  ierr = VecGetArrayRead(x, &x_array);CHKERRQ(ierr);
  ierr = VecGetArray(y, &y_array);CHKERRQ(ierr);

  /* transfer to GPU memory */
  cerr = cudaMemcpy(devArray, x_array, nbytes, cudaMemcpyHostToDevice);
  if (cerr != cudaSuccess) SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "cudaMemcpy() to the device failed: %s", cudaGetErrorString(cerr));

  /* execute transform in place on the device buffer */
  result = cufftExecC2C(cufft->p_forward, devArray, devArray, CUFFT_FORWARD);
  if (result != CUFFT_SUCCESS) SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "cufftExecC2C() failed with error code %d", (int)result);

  /* transfer from GPU memory */
  cerr = cudaMemcpy(y_array, devArray, nbytes, cudaMemcpyDeviceToHost);
  if (cerr != cudaSuccess) SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "cudaMemcpy() from the device failed: %s", cudaGetErrorString(cerr));

  ierr = VecRestoreArray(y, &y_array);CHKERRQ(ierr);
  ierr = VecRestoreArrayRead(x, &x_array);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

 /*
   MatMultTranspose_SeqCUFFT - Computes y from x with an inverse (CUFFT_INVERSE)
   complex-to-complex CUFFT transform.

   Input Parameters:
+  A - matrix whose data pointer is a Mat_CUFFT (dimensions, cached plans, device buffer)
-  x - input vector; it is only read

   Output Parameter:
.  y - receives the transformed data

   Notes:
   The backward plan is built on the first call and cached in the Mat_CUFFT; only 1, 2 and
   3 dimensional transforms are supported. The transform is executed with the backward plan,
   so this routine does not depend on MatMult() having been called first. No scaling is
   applied to the result by this routine.

   NOTE(review): the copies are sized with sizeof(cufftComplex), which presumably requires
   PetscScalar to be a single precision complex type -- confirm against the build configuration.
*/
PetscErrorCode MatMultTranspose_SeqCUFFT(Mat A, Vec x, Vec y)
{
  Mat_CUFFT         *cufft    = (Mat_CUFFT*) A->data;
  cufftComplex      *devArray = cufft->devArray;
  PetscInt          ndim      = cufft->ndim, *dim = cufft->dim;
  size_t            nbytes    = sizeof(cufftComplex)*dim[ndim]; /* dim[ndim] stores the total number of entries */
  const PetscScalar *x_array;
  PetscScalar       *y_array;
  cufftResult       result;
  cudaError_t       cerr;
  PetscErrorCode    ierr;

  PetscFunctionBegin;
  /* Build the plan before touching the vectors so that a failure cannot leave an array checked out */
  if (!cufft->p_backward) {
    switch (ndim) {
    case 1:
      result = cufftPlan1d(&cufft->p_backward, dim[0], CUFFT_C2C, 1);
      break;
    case 2:
      result = cufftPlan2d(&cufft->p_backward, dim[0], dim[1], CUFFT_C2C);
      break;
    case 3:
      result = cufftPlan3d(&cufft->p_backward, dim[0], dim[1], dim[2], CUFFT_C2C);
      break;
    default:
      SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_USER, "Cannot create plan for %D-dimensional transform", ndim);
    }
    if (result != CUFFT_SUCCESS) {
      /* keep the handle marked as "not created" so that a later call tries again */
      cufft->p_backward = 0;
      SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "CUFFT backward plan creation failed with error code %d", (int)result);
    }
  }

  ierr = VecGetArrayRead(x, &x_array);CHKERRQ(ierr);
  ierr = VecGetArray(y, &y_array);CHKERRQ(ierr);

  /* transfer to GPU memory */
  cerr = cudaMemcpy(devArray, x_array, nbytes, cudaMemcpyHostToDevice);
  if (cerr != cudaSuccess) SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "cudaMemcpy() to the device failed: %s", cudaGetErrorString(cerr));

  /* execute transform in place with the backward plan (the plan this routine owns and creates) */
  result = cufftExecC2C(cufft->p_backward, devArray, devArray, CUFFT_INVERSE);
  if (result != CUFFT_SUCCESS) SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "cufftExecC2C() failed with error code %d", (int)result);

  /* transfer from GPU memory */
  cerr = cudaMemcpy(y_array, devArray, nbytes, cudaMemcpyDeviceToHost);
  if (cerr != cudaSuccess) SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "cudaMemcpy() from the device failed: %s", cudaGetErrorString(cerr));

  ierr = VecRestoreArray(y, &y_array);CHKERRQ(ierr);
  ierr = VecRestoreArrayRead(x, &x_array);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*
   MatDestroy_SeqCUFFT - Releases everything owned by the Mat_CUFFT attached to A.

   Input Parameter:
.  A - the matrix being destroyed

   Notes:
   Frees the dimension array, destroys whichever CUFFT plans were actually created (a zero
   handle means the plan was never built), releases the device buffer, frees the Mat_CUFFT
   itself and finally clears the type name of A. Every step is checked; the first failure
   is returned to the caller.
*/
PetscErrorCode MatDestroy_SeqCUFFT(Mat A)
{
  Mat_CUFFT      *cufft = (Mat_CUFFT*) A->data;
  cufftResult    result;
  cudaError_t    cerr;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  ierr = PetscFree(cufft->dim);CHKERRQ(ierr);
  if (cufft->p_forward) {
    result = cufftDestroy(cufft->p_forward);
    if (result != CUFFT_SUCCESS) SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "cufftDestroy() of the forward plan failed with error code %d", (int)result);
    cufft->p_forward = 0;
  }
  if (cufft->p_backward) {
    result = cufftDestroy(cufft->p_backward);
    if (result != CUFFT_SUCCESS) SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "cufftDestroy() of the backward plan failed with error code %d", (int)result);
    cufft->p_backward = 0;
  }
  cerr = cudaFree(cufft->devArray);
  if (cerr != cudaSuccess) SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "cudaFree() failed: %s", cudaGetErrorString(cerr));
  cufft->devArray = NULL;
  ierr = PetscFree(A->data);CHKERRQ(ierr);
  ierr = PetscObjectChangeTypeName((PetscObject)A,0);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*@
  MatCreateSeqCUFFT - Creates a matrix object that provides sequential FFT via the external package CUFFT

  Collective on MPI_Comm

  Input Parameters:
+ comm - MPI communicator, set to PETSC_COMM_SELF
. ndim - the number of dimensions of the transform, must be > 0
- dim  - array of size ndim, dim[i] contains the vector length in the i-dimension, each must be > 0

  Output Parameter:
. A - the matrix

  Options Database Keys:
. -mat_cufft_plannerflags - set CUFFT planner flags (the options block is opened here, but this routine currently reads no option from it)

  Notes:
  The matrix is square with m = dim[0]*...*dim[ndim-1] rows. MatMult() applies the forward
  transform and MatMultTranspose() the inverse transform; both only support ndim of 1, 2 or 3
  and report an error for any other value when they are first called.

  The arguments are validated before the matrix is created, so nothing is allocated when
  they are rejected.

  Level: intermediate
@*/
PetscErrorCode  MatCreateSeqCUFFT(MPI_Comm comm, PetscInt ndim, const PetscInt dim[], Mat *A)
{
  Mat_CUFFT      *cufft;
  PetscInt       m, d;
  cudaError_t    cerr;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  /* validate everything up front: an early error return must not leak a half-built Mat */
  if (ndim < 1) SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_USER, "ndim %D must be > 0", ndim);
  m = 1;
  for (d = 0; d < ndim; ++d) {
    if (dim[d] < 1) SETERRQ2(PETSC_COMM_SELF, PETSC_ERR_USER, "dim[%D]=%D must be > 0", d, dim[d]);
    m *= dim[d];
  }

  ierr = MatCreate(comm, A);CHKERRQ(ierr);
  ierr = MatSetSizes(*A, m, m, m, m);CHKERRQ(ierr);
  ierr = PetscObjectChangeTypeName((PetscObject)*A, MATSEQCUFFT);CHKERRQ(ierr);

  ierr = PetscNewLog(*A,&cufft);CHKERRQ(ierr);
  (*A)->data = (void*) cufft;
  /* one extra slot: dim[ndim] caches the total number of entries for the multiply routines */
  ierr = PetscMalloc1(ndim+1, &cufft->dim);CHKERRQ(ierr);
  ierr = PetscMemcpy(cufft->dim, dim, ndim*sizeof(PetscInt));CHKERRQ(ierr);

  cufft->ndim       = ndim;
  cufft->p_forward  = 0; /* plans are created by the multiply routines on first use */
  cufft->p_backward = 0;
  cufft->dim[ndim]  = m;

  (*A)->ops->mult          = MatMult_SeqCUFFT;
  (*A)->ops->multtranspose = MatMultTranspose_SeqCUFFT;
  (*A)->assembled          = PETSC_TRUE;
  (*A)->ops->destroy       = MatDestroy_SeqCUFFT;

  /* GPU memory allocation */
  cerr = cudaMalloc((void**) &cufft->devArray, sizeof(cufftComplex)*m);
  if (cerr != cudaSuccess) {
    cufft->devArray = NULL;
    SETERRQ1(PETSC_COMM_SELF, PETSC_ERR_LIB, "cudaMalloc() of the transform buffer failed: %s", cudaGetErrorString(cerr));
  }

  /* get runtime options */
  ierr = PetscOptionsBegin(comm, ((PetscObject)(*A))->prefix, "CUFFT Options", "Mat");CHKERRQ(ierr);
  ierr = PetscOptionsEnd();CHKERRQ(ierr);
  PetscFunctionReturn(0);
}