Actual source code: submat.c

petsc-3.9.4 2018-09-11
Report Typos and Errors

  2:  #include <petsc/private/matimpl.h>

  4: typedef struct {
  5:   IS          isrow,iscol;      /* rows and columns in submatrix, only used to check consistency */
  6:   Vec         left,right;       /* optional scaling */
  7:   Vec         olwork,orwork;    /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
  8:   Vec         lwork,rwork;      /* work vectors inside the scatters */
  9:   VecScatter  lrestrict,rprolong;
 10:   Mat         A;
 11:   PetscScalar scale;
 12: } Mat_SubVirtual;

 14: static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
 15: {
 16:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

 20:   if (!Na->left) {
 21:     *xx = x;
 22:   } else {
 23:     if (!Na->olwork) {
 24:       VecDuplicate(Na->left,&Na->olwork);
 25:     }
 26:     VecPointwiseMult(Na->olwork,x,Na->left);
 27:     *xx  = Na->olwork;
 28:   }
 29:   return(0);
 30: }

 32: static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
 33: {
 34:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

 38:   if (!Na->right) {
 39:     *xx = x;
 40:   } else {
 41:     if (!Na->orwork) {
 42:       VecDuplicate(Na->right,&Na->orwork);
 43:     }
 44:     VecPointwiseMult(Na->orwork,x,Na->right);
 45:     *xx  = Na->orwork;
 46:   }
 47:   return(0);
 48: }

 50: static PetscErrorCode PostScaleLeft(Mat N,Vec x)
 51: {
 52:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

 56:   if (Na->left) {
 57:     VecPointwiseMult(x,x,Na->left);
 58:   }
 59:   return(0);
 60: }

 62: static PetscErrorCode PostScaleRight(Mat N,Vec x)
 63: {
 64:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

 68:   if (Na->right) {
 69:     VecPointwiseMult(x,x,Na->right);
 70:   }
 71:   return(0);
 72: }

 74: static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
 75: {
 76:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

 79:   Na->scale *= scale;
 80:   return(0);
 81: }

 83: static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
 84: {
 85:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

 89:   if (left) {
 90:     if (!Na->left) {
 91:       VecDuplicate(left,&Na->left);
 92:       VecCopy(left,Na->left);
 93:     } else {
 94:       VecPointwiseMult(Na->left,left,Na->left);
 95:     }
 96:   }
 97:   if (right) {
 98:     if (!Na->right) {
 99:       VecDuplicate(right,&Na->right);
100:       VecCopy(right,Na->right);
101:     } else {
102:       VecPointwiseMult(Na->right,right,Na->right);
103:     }
104:   }
105:   return(0);
106: }

108: static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
109: {
110:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
111:   Vec            xx  = 0;

115:   PreScaleRight(N,x,&xx);
116:   VecZeroEntries(Na->rwork);
117:   VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
118:   VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
119:   MatMult(Na->A,Na->rwork,Na->lwork);
120:   VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);
121:   VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);
122:   PostScaleLeft(N,y);
123:   VecScale(y,Na->scale);
124:   return(0);
125: }

127: static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
128: {
129:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
130:   Vec            xx  = 0;

134:   PreScaleRight(N,v1,&xx);
135:   VecZeroEntries(Na->rwork);
136:   VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
137:   VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
138:   MatMult(Na->A,Na->rwork,Na->lwork);
139:   if (v2 == v3) {
140:     if (Na->scale == (PetscScalar)1.0 && !Na->left) {
141:       VecScatterBegin(Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);
142:       VecScatterEnd  (Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);
143:     } else {
144:       if (!Na->olwork) {VecDuplicate(v3,&Na->olwork);}
145:       VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);
146:       VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);
147:       PostScaleLeft(N,Na->olwork);
148:       VecAXPY(v3,Na->scale,Na->olwork);
149:     }
150:   } else {
151:     VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);
152:     VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);
153:     PostScaleLeft(N,v3);
154:     VecAYPX(v3,Na->scale,v2);
155:   }
156:   return(0);
157: }

159: static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
160: {
161:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
162:   Vec            xx  = 0;

166:   PreScaleLeft(N,x,&xx);
167:   VecZeroEntries(Na->lwork);
168:   VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
169:   VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
170:   MatMultTranspose(Na->A,Na->lwork,Na->rwork);
171:   VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);
172:   VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);
173:   PostScaleRight(N,y);
174:   VecScale(y,Na->scale);
175:   return(0);
176: }

178: static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
179: {
180:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
181:   Vec            xx  = 0;

185:   PreScaleLeft(N,v1,&xx);
186:   VecZeroEntries(Na->lwork);
187:   VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
188:   VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
189:   MatMultTranspose(Na->A,Na->lwork,Na->rwork);
190:   if (v2 == v3) {
191:     if (Na->scale == (PetscScalar)1.0 && !Na->right) {
192:       VecScatterBegin(Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);
193:       VecScatterEnd  (Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);
194:     } else {
195:       if (!Na->orwork) {VecDuplicate(v3,&Na->orwork);}
196:       VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);
197:       VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);
198:       PostScaleRight(N,Na->orwork);
199:       VecAXPY(v3,Na->scale,Na->orwork);
200:     }
201:   } else {
202:     VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);
203:     VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);
204:     PostScaleRight(N,v3);
205:     VecAYPX(v3,Na->scale,v2);
206:   }
207:   return(0);
208: }

210: static PetscErrorCode MatDestroy_SubMatrix(Mat N)
211: {
212:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

216:   ISDestroy(&Na->isrow);
217:   ISDestroy(&Na->iscol);
218:   VecDestroy(&Na->left);
219:   VecDestroy(&Na->right);
220:   VecDestroy(&Na->olwork);
221:   VecDestroy(&Na->orwork);
222:   VecDestroy(&Na->lwork);
223:   VecDestroy(&Na->rwork);
224:   VecScatterDestroy(&Na->lrestrict);
225:   VecScatterDestroy(&Na->rprolong);
226:   MatDestroy(&Na->A);
227:   PetscFree(N->data);
228:   return(0);
229: }

231: /*@
232:    MatCreateSubMatrixVirtual - Creates a virtual matrix that acts as a submatrix

234:    Collective on Mat

236:    Input Parameters:
237: +  A - matrix that we will extract a submatrix of
238: .  isrow - rows to be present in the submatrix
239: -  iscol - columns to be present in the submatrix

241:    Output Parameters:
242: .  newmat - new matrix

244:    Level: developer

246:    Notes:
247:    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.

249: .seealso: MatCreateSubMatrix(), MatSubMatrixVirtualUpdate()
250: @*/
251: PetscErrorCode MatCreateSubMatrixVirtual(Mat A,IS isrow,IS iscol,Mat *newmat)
252: {
253:   Vec            left,right;
254:   PetscInt       m,n;
255:   Mat            N;
256:   Mat_SubVirtual *Na;

264:   *newmat = 0;

266:   MatCreate(PetscObjectComm((PetscObject)A),&N);
267:   ISGetLocalSize(isrow,&m);
268:   ISGetLocalSize(iscol,&n);
269:   MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);
270:   PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);

272:   PetscNewLog(N,&Na);
273:   N->data   = (void*)Na;
274:   PetscObjectReference((PetscObject)A);
275:   PetscObjectReference((PetscObject)isrow);
276:   PetscObjectReference((PetscObject)iscol);
277:   Na->A     = A;
278:   Na->isrow = isrow;
279:   Na->iscol = iscol;
280:   Na->scale = 1.0;

282:   N->ops->destroy          = MatDestroy_SubMatrix;
283:   N->ops->mult             = MatMult_SubMatrix;
284:   N->ops->multadd          = MatMultAdd_SubMatrix;
285:   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
286:   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
287:   N->ops->scale            = MatScale_SubMatrix;
288:   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;

290:   MatSetBlockSizesFromMats(N,A,A);
291:   PetscLayoutSetUp(N->rmap);
292:   PetscLayoutSetUp(N->cmap);

294:   MatCreateVecs(A,&Na->rwork,&Na->lwork);
295:   VecCreate(PetscObjectComm((PetscObject)isrow),&left);
296:   VecCreate(PetscObjectComm((PetscObject)iscol),&right);
297:   VecSetSizes(left,m,PETSC_DETERMINE);
298:   VecSetSizes(right,n,PETSC_DETERMINE);
299:   VecSetUp(left);
300:   VecSetUp(right);
301:   VecScatterCreate(Na->lwork,isrow,left,NULL,&Na->lrestrict);
302:   VecScatterCreate(right,NULL,Na->rwork,iscol,&Na->rprolong);
303:   VecDestroy(&left);
304:   VecDestroy(&right);

306:   N->assembled = PETSC_TRUE;

308:   MatSetUp(N);

310:   *newmat      = N;
311:   return(0);
312: }


315: /*@
316:    MatSubMatrixVirtualUpdate - Updates a submatrix

318:    Collective on Mat

320:    Input Parameters:
321: +  N - submatrix to update
322: .  A - full matrix in the submatrix
323: .  isrow - rows in the update (same as the first time the submatrix was created)
324: -  iscol - columns in the update (same as the first time the submatrix was created)

326:    Level: developer

328:    Notes:
329:    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.

331: .seealso: MatCreateSubMatrixVirtual()
332: @*/
333: PetscErrorCode  MatSubMatrixVirtualUpdate(Mat N,Mat A,IS isrow,IS iscol)
334: {
336:   PetscBool      flg;
337:   Mat_SubVirtual *Na;

344:   PetscObjectTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);
345:   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix has wrong type");

347:   Na   = (Mat_SubVirtual*)N->data;
348:   ISEqual(isrow,Na->isrow,&flg);
349:   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
350:   ISEqual(iscol,Na->iscol,&flg);
351:   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");

353:   PetscObjectReference((PetscObject)A);
354:   MatDestroy(&Na->A);
355:   Na->A = A;

357:   Na->scale = 1.0;
358:   VecDestroy(&Na->left);
359:   VecDestroy(&Na->right);
360:   return(0);
361: }