#include <petsc/private/isimpl.h>
#include <petsc/private/vecimpl.h> /*I "petscvec.h" I*/

PETSC_INTERN PetscErrorCode VecScatterCUSPIndicesCreate_PtoP(PetscInt,PetscInt*,PetscInt,PetscInt*,PetscCUSPIndices*);
/*@
   VecScatterInitializeForGPU - Initializes a generalized scatter from one vector to
   another for GPU-based computation. Effectively, this function creates all the
   necessary indexing buffers and work vectors needed to move only those entries of a
   vector which need to be communicated across ranks. This is done the first time
   this function is called. Currently, this is only used in the context of the parallel
   SpMV call in MatMult_MPIAIJCUSP (in mpi/mpicusp/mpiaijcusp.cu) or MatMult_MPIAIJCUSPARSE
   (in mpi/mpicusparse/mpiaijcusparse.cu). This function is executed before the call to
   MatMult, which enables the memory transfers to be overlapped with the MatMult SpMV
   kernel call.

   Input Parameters:
+  inctx - scatter context generated by VecScatterCreate()
.  x     - the vector from which we scatter
-  mode  - the scattering mode, usually SCATTER_FORWARD; the available modes are
           SCATTER_FORWARD and SCATTER_REVERSE

   Level: intermediate

.seealso: VecScatterCreate(), VecScatterEnd()
@*/
PetscErrorCode VecScatterInitializeForGPU(VecScatter inctx,Vec x,ScatterMode mode)
{
  VecScatter_MPI_General *to,*from;
  PetscErrorCode         ierr;
  PetscInt               i,*indices,*sstartsSends,*sstartsRecvs,nrecvs,nsends,bs;
  PetscBool              isSeq1,isSeq2;

  PetscFunctionBegin;
  ierr = VecScatterIsSequential_Private((VecScatter_Common*)inctx->fromdata,&isSeq1);CHKERRQ(ierr);
  ierr = VecScatterIsSequential_Private((VecScatter_Common*)inctx->todata,&isSeq2);CHKERRQ(ierr);
  if (isSeq1 || isSeq2) {
    PetscFunctionReturn(0);
  }
  if (mode & SCATTER_REVERSE) {
    to   = (VecScatter_MPI_General*)inctx->fromdata;
    from = (VecScatter_MPI_General*)inctx->todata;
  } else {
    to   = (VecScatter_MPI_General*)inctx->todata;
    from = (VecScatter_MPI_General*)inctx->fromdata;
  }
  bs           = to->bs;
  nrecvs       = from->n;
  nsends       = to->n;
  indices      = to->indices;
  sstartsSends = to->starts;
  sstartsRecvs = from->starts;
  if (x->valid_GPU_array != PETSC_CUSP_UNALLOCATED && (nsends>0 || nrecvs>0)) {
    if (!inctx->spptr) {
      PetscInt k,*tindicesSends,*sindicesSends,*tindicesRecvs,*sindicesRecvs;
      PetscInt ns = sstartsSends[nsends],nr = sstartsRecvs[nrecvs];
      /* Here we create indices for both the senders and receivers. */
      ierr = PetscMalloc1(ns,&tindicesSends);CHKERRQ(ierr);
      ierr = PetscMalloc1(nr,&tindicesRecvs);CHKERRQ(ierr);

      ierr = PetscMemcpy(tindicesSends,indices,ns*sizeof(PetscInt));CHKERRQ(ierr);
      ierr = PetscMemcpy(tindicesRecvs,from->indices,nr*sizeof(PetscInt));CHKERRQ(ierr);

      /* Sort and remove duplicate indices so each entry appears only once. */
      ierr = PetscSortRemoveDupsInt(&ns,tindicesSends);CHKERRQ(ierr);
      ierr = PetscSortRemoveDupsInt(&nr,tindicesRecvs);CHKERRQ(ierr);

      ierr = PetscMalloc1(bs*ns,&sindicesSends);CHKERRQ(ierr);
      ierr = PetscMalloc1(from->bs*nr,&sindicesRecvs);CHKERRQ(ierr);

      /* sender indices: expand each block index into bs consecutive entries */
      for (i=0; i<ns; i++) {
        for (k=0; k<bs; k++) sindicesSends[i*bs+k] = tindicesSends[i]+k;
      }
      ierr = PetscFree(tindicesSends);CHKERRQ(ierr);

      /* receiver indices: expand each block index into from->bs consecutive entries */
      for (i=0; i<nr; i++) {
        for (k=0; k<from->bs; k++) sindicesRecvs[i*from->bs+k] = tindicesRecvs[i]+k;
      }
      ierr = PetscFree(tindicesRecvs);CHKERRQ(ierr);

      /* create GPU indices, work vectors, ... */
      ierr = VecScatterCUSPIndicesCreate_PtoP(ns*bs,sindicesSends,nr*from->bs,sindicesRecvs,(PetscCUSPIndices*)&inctx->spptr);CHKERRQ(ierr);
      ierr = PetscFree(sindicesSends);CHKERRQ(ierr);
      ierr = PetscFree(sindicesRecvs);CHKERRQ(ierr);
    }
  }
  PetscFunctionReturn(0);
}
/*@
   VecScatterFinalizeForGPU - Finalizes a generalized scatter from one vector to
   another for GPU-based computation. Effectively, this function resets the temporary
   buffer flags. Currently, this is only used in the context of the parallel SpMV call
   in MatMult_MPIAIJCUSP (in mpi/mpicusp/mpiaijcusp.cu) or MatMult_MPIAIJCUSPARSE
   (in mpi/mpicusparse/mpiaijcusparse.cu). Once the MatMultAdd is finished,
   the GPU temporary buffers used for messaging are no longer valid.

   Input Parameter:
.  inctx - scatter context generated by VecScatterCreate()

   Level: intermediate

@*/
PetscErrorCode VecScatterFinalizeForGPU(VecScatter inctx)
{
  PetscFunctionBegin;
  PetscFunctionReturn(0);
}
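
/*
   The sketch below is illustrative only (compiled out via #if 0). It shows, under the
   assumptions stated in the comments, the calling pattern described in the documentation
   above: VecScatterInitializeForGPU() before the scatter and local SpMV so that message
   packing and transfer can overlap the kernel, and VecScatterFinalizeForGPU() once the
   off-diagonal MatMultAdd is done. The routine name MatMultSketch_MPIAIJGPU and the use
   of the Mat_MPIAIJ members (Mvctx, lvec, A, B) are assumptions for illustration; the
   real implementations are MatMult_MPIAIJCUSP() and MatMult_MPIAIJCUSPARSE().
*/
#if 0
#include <../src/mat/impls/aij/mpi/mpiaij.h>  /* assumed location of the Mat_MPIAIJ definition */

static PetscErrorCode MatMultSketch_MPIAIJGPU(Mat A,Vec xx,Vec yy)
{
  Mat_MPIAIJ     *a = (Mat_MPIAIJ*)A->data;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  /* Build the GPU send/receive index buffers (first call only) before starting the
     scatter, so the transfers can overlap the diagonal-block SpMV below. */
  ierr = VecScatterInitializeForGPU(a->Mvctx,xx,SCATTER_FORWARD);CHKERRQ(ierr);
  ierr = VecScatterBegin(a->Mvctx,xx,a->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
  ierr = (*a->A->ops->mult)(a->A,xx,yy);CHKERRQ(ierr);            /* local (diagonal block) SpMV */
  ierr = VecScatterEnd(a->Mvctx,xx,a->lvec,INSERT_VALUES,SCATTER_FORWARD);CHKERRQ(ierr);
  ierr = (*a->B->ops->multadd)(a->B,a->lvec,yy,yy);CHKERRQ(ierr); /* off-diagonal block SpMV add */
  /* Mark the temporary GPU messaging buffers as no longer in use. */
  ierr = VecScatterFinalizeForGPU(a->Mvctx);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}
#endif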