Actual source code: submat.c
petsc-3.11.4 2019-09-28
2: #include <petsc/private/matimpl.h>
4: typedef struct {
5: IS isrow,iscol; /* rows and columns in submatrix, only used to check consistency */
6: Vec left,right; /* optional scaling */
7: Vec olwork,orwork; /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
8: Vec lwork,rwork; /* work vectors inside the scatters */
9: Vec dshift;
10: VecScatter lrestrict,rprolong;
11: Mat A;
12: PetscScalar vscale, axpy_vscale;
13: PetscScalar vshift, axpy_vshift;
14: } Mat_SubVirtual;
16: static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
17: {
18: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
22: if (!Na->left) {
23: *xx = x;
24: } else {
25: if (!Na->olwork) {
26: VecDuplicate(Na->left,&Na->olwork);
27: }
28: VecPointwiseMult(Na->olwork,x,Na->left);
29: *xx = Na->olwork;
30: }
31: return(0);
32: }
34: static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
35: {
36: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
40: if (!Na->right) {
41: *xx = x;
42: } else {
43: if (!Na->orwork) {
44: VecDuplicate(Na->right,&Na->orwork);
45: }
46: VecPointwiseMult(Na->orwork,x,Na->right);
47: *xx = Na->orwork;
48: }
49: return(0);
50: }
52: static PetscErrorCode PostScaleLeft(Mat N,Vec x)
53: {
54: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
58: if (Na->left) {
59: VecPointwiseMult(x,x,Na->left);
60: }
61: return(0);
62: }
64: static PetscErrorCode PostScaleRight(Mat N,Vec x)
65: {
66: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
70: if (Na->right) {
71: VecPointwiseMult(x,x,Na->right);
72: }
73: return(0);
74: }
76: /*
77: Y = vscale*Y + diag(dshift)*X + vshift*X
79: On input Y already contains A*x
80: */
81: static PetscErrorCode MatSubmatShiftAndScale(Mat A,Vec X,Vec Y)
82: {
83: Mat_SubVirtual *Na = (Mat_SubVirtual*)A->data;
87: if (Na->dshift) { /* get arrays because there is no VecPointwiseMultAdd() */
88: PetscInt i,m;
89: const PetscScalar *x,*d;
90: PetscScalar *y;
91: VecGetLocalSize(X,&m);
92: VecGetArrayRead(Na->dshift,&d);
93: VecGetArrayRead(X,&x);
94: VecGetArray(Y,&y);
95: for (i=0; i<m; i++) y[i] = Na->vscale*y[i] + d[i]*x[i];
96: VecRestoreArrayRead(Na->dshift,&d);
97: VecRestoreArrayRead(X,&x);
98: VecRestoreArray(Y,&y);
99: } else {
100: VecScale(Y,Na->vscale);
101: }
102: if (Na->vshift != 0.0) {VecAXPY(Y,Na->vshift,X);} /* if test is for non-square matrices */
103: return(0);
104: }
106: static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar a)
107: {
108: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
112: Na->vscale *= a;
113: Na->vshift *= a;
114: if (Na->dshift) {
115: VecScale(Na->dshift,a);
116: }
117: Na->axpy_vscale *= a;
118: return(0);
119: }
121: static PetscErrorCode MatShift_SubMatrix(Mat N,PetscScalar a)
122: {
123: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
127: if (Na->left || Na->right) {
128: if (!Na->dshift) {
129: VecDuplicate(Na->left ? Na->left : Na->right, &Na->dshift);
130: VecSet(Na->dshift,a);
131: } else {
132: if (Na->left) {VecPointwiseMult(Na->dshift,Na->dshift,Na->left);}
133: if (Na->right) {VecPointwiseMult(Na->dshift,Na->dshift,Na->right);}
134: VecShift(Na->dshift,a);
135: }
136: if (Na->left) {VecPointwiseDivide(Na->dshift,Na->dshift,Na->left);}
137: if (Na->right) {VecPointwiseDivide(Na->dshift,Na->dshift,Na->right);}
138: } else Na->vshift += a;
139: return(0);
140: }
142: static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
143: {
144: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
148: if (left) {
149: if (!Na->left) {
150: VecDuplicate(left,&Na->left);
151: VecCopy(left,Na->left);
152: } else {
153: VecPointwiseMult(Na->left,left,Na->left);
154: }
155: }
156: if (right) {
157: if (!Na->right) {
158: VecDuplicate(right,&Na->right);
159: VecCopy(right,Na->right);
160: } else {
161: VecPointwiseMult(Na->right,right,Na->right);
162: }
163: }
164: return(0);
165: }
167: static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
168: {
169: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
170: Vec xx = 0;
174: PreScaleRight(N,x,&xx);
175: VecZeroEntries(Na->rwork);
176: VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
177: VecScatterEnd (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
178: MatMult(Na->A,Na->rwork,Na->lwork);
179: VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);
180: VecScatterEnd (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);
181: MatSubmatShiftAndScale(N,xx,y);
182: PostScaleLeft(N,y);
183: return(0);
184: }
186: static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
187: {
188: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
189: Vec xx = 0;
193: PreScaleRight(N,v1,&xx);
194: VecZeroEntries(Na->rwork);
195: VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
196: VecScatterEnd (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
197: MatMult(Na->A,Na->rwork,Na->lwork);
198: if (v2 == v3) {
199: if (!Na->olwork) {VecDuplicate(v3,&Na->olwork);}
200: VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);
201: VecScatterEnd (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);
202: MatSubmatShiftAndScale(N,xx,Na->olwork);
203: PostScaleLeft(N,Na->olwork);
204: VecAXPY(v3,1.0,Na->olwork);
205: } else {
206: VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);
207: VecScatterEnd (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);
208: MatSubmatShiftAndScale(N,xx,v3);
209: PostScaleLeft(N,v3);
210: VecAXPY(v3,1.0,v2);
211: }
212: return(0);
213: }
215: static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
216: {
217: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
218: Vec xx = 0;
222: PreScaleLeft(N,x,&xx);
223: VecZeroEntries(Na->lwork);
224: VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
225: VecScatterEnd (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
226: MatMultTranspose(Na->A,Na->lwork,Na->rwork);
227: VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);
228: VecScatterEnd (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);
229: MatSubmatShiftAndScale(N,xx,y);
230: PostScaleRight(N,y);
231: return(0);
232: }
234: static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
235: {
236: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
237: Vec xx = 0;
241: PreScaleLeft(N,v1,&xx);
242: VecZeroEntries(Na->lwork);
243: VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
244: VecScatterEnd (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
245: MatMultTranspose(Na->A,Na->lwork,Na->rwork);
246: if (v2 == v3) {
247: if (!Na->orwork) {VecDuplicate(v3,&Na->orwork);}
248: VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);
249: VecScatterEnd (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);
250: MatSubmatShiftAndScale(N,xx,Na->orwork);
251: PostScaleRight(N,Na->orwork);
252: VecAXPY(v3,1.0,Na->orwork);
253: } else {
254: VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);
255: VecScatterEnd (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);
256: MatSubmatShiftAndScale(N,xx,v3);
257: PostScaleRight(N,v3);
258: VecAXPY(v3,1.0,v2);
259: }
260: return(0);
261: }
263: static PetscErrorCode MatDestroy_SubMatrix(Mat N)
264: {
265: Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
269: ISDestroy(&Na->isrow);
270: ISDestroy(&Na->iscol);
271: VecDestroy(&Na->left);
272: VecDestroy(&Na->right);
273: VecDestroy(&Na->olwork);
274: VecDestroy(&Na->orwork);
275: VecDestroy(&Na->lwork);
276: VecDestroy(&Na->rwork);
277: VecDestroy(&Na->dshift);
278: VecScatterDestroy(&Na->lrestrict);
279: VecScatterDestroy(&Na->rprolong);
280: MatDestroy(&Na->A);
281: PetscFree(N->data);
282: return(0);
283: }
285: /*@
286: MatCreateSubMatrixVirtual - Creates a virtual matrix that acts as a submatrix
288: Collective on Mat
290: Input Parameters:
291: + A - matrix that we will extract a submatrix of
292: . isrow - rows to be present in the submatrix
293: - iscol - columns to be present in the submatrix
295: Output Parameters:
296: . newmat - new matrix
298: Level: developer
300: Notes:
301: Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
303: .seealso: MatCreateSubMatrix(), MatSubMatrixVirtualUpdate()
304: @*/
305: PetscErrorCode MatCreateSubMatrixVirtual(Mat A,IS isrow,IS iscol,Mat *newmat)
306: {
307: Vec left,right;
308: PetscInt m,n;
309: Mat N;
310: Mat_SubVirtual *Na;
318: *newmat = 0;
320: MatCreate(PetscObjectComm((PetscObject)A),&N);
321: ISGetLocalSize(isrow,&m);
322: ISGetLocalSize(iscol,&n);
323: MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);
324: PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);
326: PetscNewLog(N,&Na);
327: N->data = (void*)Na;
328: PetscObjectReference((PetscObject)A);
329: PetscObjectReference((PetscObject)isrow);
330: PetscObjectReference((PetscObject)iscol);
331: Na->A = A;
332: Na->isrow = isrow;
333: Na->iscol = iscol;
334: Na->vscale = 1.0;
335: Na->vshift = 0.0;
337: N->ops->destroy = MatDestroy_SubMatrix;
338: N->ops->mult = MatMult_SubMatrix;
339: N->ops->multadd = MatMultAdd_SubMatrix;
340: N->ops->multtranspose = MatMultTranspose_SubMatrix;
341: N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
342: N->ops->scale = MatScale_SubMatrix;
343: N->ops->diagonalscale = MatDiagonalScale_SubMatrix;
344: N->ops->shift = MatShift_SubMatrix;
346: MatSetBlockSizesFromMats(N,A,A);
347: PetscLayoutSetUp(N->rmap);
348: PetscLayoutSetUp(N->cmap);
350: MatCreateVecs(A,&Na->rwork,&Na->lwork);
351: VecCreate(PetscObjectComm((PetscObject)isrow),&left);
352: VecCreate(PetscObjectComm((PetscObject)iscol),&right);
353: VecSetSizes(left,m,PETSC_DETERMINE);
354: VecSetSizes(right,n,PETSC_DETERMINE);
355: VecSetUp(left);
356: VecSetUp(right);
357: VecScatterCreate(Na->lwork,isrow,left,NULL,&Na->lrestrict);
358: VecScatterCreate(right,NULL,Na->rwork,iscol,&Na->rprolong);
359: VecDestroy(&left);
360: VecDestroy(&right);
362: N->assembled = PETSC_TRUE;
364: MatSetUp(N);
366: *newmat = N;
367: return(0);
368: }
371: /*@
372: MatSubMatrixVirtualUpdate - Updates a submatrix
374: Collective on Mat
376: Input Parameters:
377: + N - submatrix to update
378: . A - full matrix in the submatrix
379: . isrow - rows in the update (same as the first time the submatrix was created)
380: - iscol - columns in the update (same as the first time the submatrix was created)
382: Level: developer
384: Notes:
385: Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.
387: .seealso: MatCreateSubMatrixVirtual()
388: @*/
389: PetscErrorCode MatSubMatrixVirtualUpdate(Mat N,Mat A,IS isrow,IS iscol)
390: {
392: PetscBool flg;
393: Mat_SubVirtual *Na;
400: PetscObjectTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);
401: if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix has wrong type");
403: Na = (Mat_SubVirtual*)N->data;
404: ISEqual(isrow,Na->isrow,&flg);
405: if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
406: ISEqual(iscol,Na->iscol,&flg);
407: if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");
409: PetscObjectReference((PetscObject)A);
410: MatDestroy(&Na->A);
411: Na->A = A;
413: Na->vshift = 0.0;
414: Na->vscale = 1.0;
415: VecDestroy(&Na->left);
416: VecDestroy(&Na->right);
417: return(0);
418: }