Actual source code: ex47cu.cu
petsc-3.10.5 2019-03-28
/*
  Same as ex47.c, except that the residual (ComputeFunction) can also be evaluated on
  the GPU with Thrust when the vectors are CUDA vectors, e.g. run with -dm_vec_type cuda.
*/
static char help[] = "Solves -Laplacian u - exp(u) = 0, 0 < x < 1 using GPU\n\n";
#include <petscdm.h>
#include <petscdmda.h>
#include <petscsnes.h>
#include <petsccuda.h>

#include <thrust/device_ptr.h>
#include <thrust/for_each.h>
#include <thrust/tuple.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>

#include <exception>
/* Residual and Jacobian callbacks registered with SNES in main(); defined further down. */
extern PetscErrorCode ComputeFunction(SNES,Vec,Vec,void*);
extern PetscErrorCode ComputeJacobian(SNES,Vec,Mat,Mat,void*);

/* Turned on in main() when -dm_vec_type contains "cuda"; makes ComputeFunction use Thrust. */
PetscBool useCUDA = PETSC_FALSE;
/*
  main - solve -u'' - exp(u) = 0 on (0,1) with u = 0 at both ends.

  Builds a 1D DMDA (8 points by default, stencil width 1), a global iterate x, a residual
  vector f and a DMDA-preallocated AIJ Jacobian. It registers ComputeFunction and
  ComputeJacobian with SNES and solves from the zero initial guess.

  If -dm_vec_type names a CUDA vector type, the global flag useCUDA is set. ComputeFunction
  then evaluates the residual on the GPU.
*/
int main(int argc,char **argv)
{
  SNES           snes;
  Vec            x,f;
  Mat            J;
  DM             da;
  PetscErrorCode ierr;
  char           *tmp,typeName[256];
  PetscBool      flg;

  ierr = PetscInitialize(&argc,&argv,(char*)0,help);if (ierr) return ierr;

  /* Any vector type containing "cuda" (cuda, seqcuda, mpicuda) selects the Thrust residual. */
  ierr = PetscOptionsGetString(NULL,NULL,"-dm_vec_type",typeName,sizeof(typeName),&flg);CHKERRQ(ierr);
  if (flg) {
    ierr = PetscStrstr(typeName,"cuda",&tmp);CHKERRQ(ierr);
    if (tmp) useCUDA = PETSC_TRUE;
  }

  /* Non-periodic 1D grid, 1 dof per point, stencil width 1 (the 3-point Laplacian). */
  ierr = DMDACreate1d(PETSC_COMM_WORLD,DM_BOUNDARY_NONE,8,1,1,NULL,&da);CHKERRQ(ierr);
  ierr = DMSetFromOptions(da);CHKERRQ(ierr);
  ierr = DMSetUp(da);CHKERRQ(ierr);
  ierr = DMCreateGlobalVector(da,&x);CHKERRQ(ierr);
  ierr = VecDuplicate(x,&f);CHKERRQ(ierr);
  ierr = DMSetMatType(da,MATAIJ);CHKERRQ(ierr);
  ierr = DMCreateMatrix(da,&J);CHKERRQ(ierr);

  /* The DMDA is the user context of both callbacks; J is both operator and preconditioner matrix. */
  ierr = SNESCreate(PETSC_COMM_WORLD,&snes);CHKERRQ(ierr);
  ierr = SNESSetFunction(snes,f,ComputeFunction,da);CHKERRQ(ierr);
  ierr = SNESSetJacobian(snes,J,J,ComputeJacobian,da);CHKERRQ(ierr);
  ierr = SNESSetFromOptions(snes);CHKERRQ(ierr);
  ierr = SNESSolve(snes,NULL,x);CHKERRQ(ierr);

  ierr = MatDestroy(&J);CHKERRQ(ierr);
  ierr = VecDestroy(&x);CHKERRQ(ierr);
  ierr = VecDestroy(&f);CHKERRQ(ierr);
  ierr = SNESDestroy(&snes);CHKERRQ(ierr);
  ierr = DMDestroy(&da);CHKERRQ(ierr);

  ierr = PetscFinalize();
  return ierr;
}
/*
  ApplyStencil - Thrust functor that computes one residual entry per call.

  The zipped tuple t holds, in order:
    <0> f_i (written)   <1> x_i   <2> x_{i+1}   <3> x_{i-1}
    <4> global index i  <5> Mx    <6> hx

  Results:
    - Edge points (i == 0 or i == Mx-1): f_i = x_i/hx.
    - Interior points: f_i = (2*x_i - x_{i+1} - x_{i-1})/hx - hx*exp(x_i).
    - Any other index: f_i = 1.

  The neighbour slots <2> and <3> are read only for interior points.
*/
struct ApplyStencil
{
  template <typename Tuple>
  __host__ __device__
  void operator()(Tuple t)
  {
    const bool onLeftEdge  = (thrust::get<4>(t) == 0);
    const bool onRightEdge = (thrust::get<4>(t) == thrust::get<5>(t)-1);

    if (onLeftEdge || onRightEdge) {
      thrust::get<0>(t) = thrust::get<1>(t) / (thrust::get<6>(t));
    } else if ((thrust::get<4>(t) > 0) && (thrust::get<4>(t) < thrust::get<5>(t)-1)) {
      thrust::get<0>(t) = (2.0*thrust::get<1>(t) - thrust::get<2>(t) - thrust::get<3>(t)) / (thrust::get<6>(t))
                          - (thrust::get<6>(t))*exp(thrust::get<1>(t));
    } else {
      thrust::get<0>(t) = 1;
    }
  }
};
/*
  ComputeFunctionCUDA - evaluate the residual for the locally owned points on the GPU with Thrust.

  Input Parameters:
+ da     - the 1D DMDA (stencil width 1)
. xlocal - ghosted local copy of the current iterate; must be a CUDA vector
. Mx     - global number of grid points
- hx     - mesh spacing 1/(Mx-1)

  Output Parameter:
. f - residual; must be a CUDA vector. Only the locally owned entries are written.
*/
static PetscErrorCode ComputeFunctionCUDA(DM da,Vec xlocal,Vec f,PetscInt Mx,PetscScalar hx)
{
  PetscErrorCode    ierr;
  PetscInt          xs,xm,gxs;
  const PetscScalar *xarray,*xown;
  PetscScalar       *farray;
  PetscBool         thrustFailed = PETSC_FALSE;

  PetscFunctionBeginUser;
  /* Owned points are [xs,xs+xm); the ghosted range starts at gxs. So xs-gxs ghost values precede
     the owned ones in xlocal. That is 0 on the first process of a non-periodic DMDA, and 1 elsewhere. */
  ierr = DMDAGetCorners(da,&xs,NULL,NULL,&xm,NULL,NULL);CHKERRQ(ierr);
  ierr = DMDAGetGhostCorners(da,&gxs,NULL,NULL,NULL,NULL,NULL);CHKERRQ(ierr);
  ierr = VecCUDAGetArrayRead(xlocal,&xarray);CHKERRQ(ierr);
  ierr = VecCUDAGetArrayWrite(f,&farray);CHKERRQ(ierr);
  xown = xarray + (xs - gxs);
  try {
    /* One functor call per owned point. The count is passed explicitly, so every zipped
       sequence advances in lockstep from its begin position.
       xown-1 (first process) and xown+xm (last process) lie outside xlocal. They are never
       dereferenced, because ApplyStencil reads neighbours only for interior points. */
    thrust::for_each_n(
      thrust::make_zip_iterator(
        thrust::make_tuple(
          thrust::device_ptr<PetscScalar>(farray),          /* <0> f_i      */
          thrust::device_ptr<const PetscScalar>(xown),      /* <1> x_i      */
          thrust::device_ptr<const PetscScalar>(xown + 1),  /* <2> x_{i+1}  */
          thrust::device_ptr<const PetscScalar>(xown - 1),  /* <3> x_{i-1}  */
          thrust::counting_iterator<PetscInt>(xs),          /* <4> global i */
          thrust::constant_iterator<PetscInt>(Mx),          /* <5> Mx       */
          thrust::constant_iterator<PetscScalar>(hx))),     /* <6> hx       */
      xm,
      ApplyStencil());
  } catch (const std::exception &e) {
    /* Thrust reports failures by throwing std::exception subclasses (thrust::system_error,
       thrust::bad_alloc). A catch (char *) never matches them. Record the failure so the
       arrays are still restored before the error is raised. */
    thrustFailed = PETSC_TRUE;
    ierr = PetscPrintf(PETSC_COMM_SELF,"Thrust is not working: %s\n",e.what());CHKERRQ(ierr);
  }
  ierr = VecCUDARestoreArrayRead(xlocal,&xarray);CHKERRQ(ierr);
  ierr = VecCUDARestoreArrayWrite(f,&farray);CHKERRQ(ierr);
  if (thrustFailed) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_LIB,"Thrust is not working");
  PetscFunctionReturn(0);
}

/*
  ComputeFunctionHost - evaluate the same residual as ComputeFunctionCUDA on the CPU.

  Uses DMDA arrays indexed by global grid index.
*/
static PetscErrorCode ComputeFunctionHost(DM da,Vec xlocal,Vec f,PetscInt Mx,PetscScalar hx)
{
  PetscErrorCode ierr;
  PetscInt       i,xs,xm;
  PetscScalar    *xx,*ff;

  PetscFunctionBeginUser;
  ierr = DMDAVecGetArray(da,xlocal,&xx);CHKERRQ(ierr);
  ierr = DMDAVecGetArray(da,f,&ff);CHKERRQ(ierr);
  ierr = DMDAGetCorners(da,&xs,NULL,NULL,&xm,NULL,NULL);CHKERRQ(ierr);
  for (i=xs; i<xs+xm; i++) {
    if (i == 0 || i == Mx-1) ff[i] = xx[i]/hx;   /* Dirichlet rows: drive u to 0 */
    else ff[i] = (2.0*xx[i] - xx[i-1] - xx[i+1])/hx - hx*PetscExpScalar(xx[i]);
  }
  ierr = DMDAVecRestoreArray(da,xlocal,&xx);CHKERRQ(ierr);
  ierr = DMDAVecRestoreArray(da,f,&ff);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*
  ComputeFunction - SNES residual callback for -u'' - exp(u) = 0 on a 1D DMDA.

  Input Parameters:
+ snes - the nonlinear solver (unused)
. x    - current iterate (global vector)
- ctx  - the DMDA, passed as the context in main()

  Output Parameter:
. f - residual. Boundary rows hold x_i/hx; interior rows hold
      (2*x_i - x_{i-1} - x_{i+1})/hx - hx*exp(x_i).

  The residual is computed with Thrust when the global flag useCUDA is set, and on the
  host otherwise.
*/
PetscErrorCode ComputeFunction(SNES snes,Vec x,Vec f,void *ctx)
{
  DM             da = (DM) ctx;
  Vec            xlocal;
  PetscInt       Mx;
  PetscScalar    hx;
  PetscErrorCode ierr;

  PetscFunctionBeginUser;
  ierr = DMDAGetInfo(da,PETSC_IGNORE,&Mx,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE);CHKERRQ(ierr);
  /* hx = 1/(Mx-1) would divide by zero on a one-point grid. */
  if (Mx < 2) SETERRQ1(PetscObjectComm((PetscObject)da),PETSC_ERR_ARG_OUTOFRANGE,"Need at least 2 grid points, got %D",Mx);
  hx = 1.0/(PetscReal)(Mx-1);

  /* Scatter the iterate into a ghosted local vector so that neighbour values are available. */
  ierr = DMGetLocalVector(da,&xlocal);CHKERRQ(ierr);
  ierr = DMGlobalToLocalBegin(da,x,INSERT_VALUES,xlocal);CHKERRQ(ierr);
  ierr = DMGlobalToLocalEnd(da,x,INSERT_VALUES,xlocal);CHKERRQ(ierr);

  if (useCUDA) {
    ierr = ComputeFunctionCUDA(da,xlocal,f,Mx,hx);CHKERRQ(ierr);
  } else {
    ierr = ComputeFunctionHost(da,xlocal,f,Mx,hx);CHKERRQ(ierr);
  }

  ierr = DMRestoreLocalVector(da,&xlocal);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
/*
  ComputeJacobian - SNES Jacobian callback matching the residual of ComputeFunction.

  Input Parameters:
+ snes - the nonlinear solver (unused)
. x    - current iterate (global vector)
- ctx  - the DMDA, passed as the context in main()

  Output Parameters:
+ J - the Jacobian operator. It is always assembled, but its values are set here only when
      it is the same matrix as B (as in main()).
- B - the matrix used to build the preconditioner. It receives the entries.

  Entries of B:
    - Boundary rows (i = 0, Mx-1): 1/hx on the diagonal.
    - Interior rows: -1/hx, 2/hx - hx*exp(x_i), -1/hx in columns i-1, i, i+1.

  Entries are always computed on the host, whatever the value of useCUDA.
*/
PetscErrorCode ComputeJacobian(SNES snes,Vec x,Mat J,Mat B,void *ctx)
{
  DM             da = (DM) ctx;
  PetscInt       i,Mx,xm,xs;
  PetscScalar    hx,*xx;
  Vec            xlocal;
  PetscErrorCode ierr;

  PetscFunctionBeginUser;
  ierr = DMDAGetInfo(da,PETSC_IGNORE,&Mx,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE,PETSC_IGNORE);CHKERRQ(ierr);
  /* hx = 1/(Mx-1) would divide by zero on a one-point grid. */
  if (Mx < 2) SETERRQ1(PetscObjectComm((PetscObject)da),PETSC_ERR_ARG_OUTOFRANGE,"Need at least 2 grid points, got %D",Mx);
  hx = 1.0/(PetscReal)(Mx-1);

  /* Ghosted copy of x. Only xx[i] for owned i is read below, but the DMDA array view needs a local vector. */
  ierr = DMGetLocalVector(da,&xlocal);CHKERRQ(ierr);
  ierr = DMGlobalToLocalBegin(da,x,INSERT_VALUES,xlocal);CHKERRQ(ierr);
  ierr = DMGlobalToLocalEnd(da,x,INSERT_VALUES,xlocal);CHKERRQ(ierr);
  ierr = DMDAVecGetArray(da,xlocal,&xx);CHKERRQ(ierr);
  ierr = DMDAGetCorners(da,&xs,NULL,NULL,&xm,NULL,NULL);CHKERRQ(ierr);

  for (i=xs; i<xs+xm; i++) {
    if (i == 0 || i == Mx-1) {
      ierr = MatSetValue(B,i,i,1.0/hx,INSERT_VALUES);CHKERRQ(ierr);
    } else {
      ierr = MatSetValue(B,i,i-1,-1.0/hx,INSERT_VALUES);CHKERRQ(ierr);
      ierr = MatSetValue(B,i,i,2.0/hx - hx*PetscExpScalar(xx[i]),INSERT_VALUES);CHKERRQ(ierr);
      ierr = MatSetValue(B,i,i+1,-1.0/hx,INSERT_VALUES);CHKERRQ(ierr);
    }
  }
  ierr = DMDAVecRestoreArray(da,xlocal,&xx);CHKERRQ(ierr);
  ierr = DMRestoreLocalVector(da,&xlocal);CHKERRQ(ierr);

  ierr = MatAssemblyBegin(B,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  ierr = MatAssemblyEnd(B,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  /* A distinct operator (e.g. matrix-free with -snes_mf_operator) only needs to be assembled. */
  if (J != B) {
    ierr = MatAssemblyBegin(J,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
    ierr = MatAssemblyEnd(J,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}
187: /*TEST
189: build:
190: requires: cuda
192: test:
193: args: -snes_monitor_short -dm_vec_type cuda
195: TEST*/