Actual source code: transm.c

  1: #include <../src/mat/impls/shell/shell.h>

  3: static PetscErrorCode MatMult_Transpose(Mat N, Vec x, Vec y)
  4: {
  5:   Mat A;

  7:   PetscFunctionBegin;
  8:   PetscCall(MatShellGetContext(N, &A));
  9:   PetscCall(MatMultTranspose(A, x, y));
 10:   PetscFunctionReturn(PETSC_SUCCESS);
 11: }

 13: static PetscErrorCode MatMultTranspose_Transpose(Mat N, Vec x, Vec y)
 14: {
 15:   Mat A;

 17:   PetscFunctionBegin;
 18:   PetscCall(MatShellGetContext(N, &A));
 19:   PetscCall(MatMult(A, x, y));
 20:   PetscFunctionReturn(PETSC_SUCCESS);
 21: }

 23: static PetscErrorCode MatDestroy_Transpose(Mat N)
 24: {
 25:   Mat A;

 27:   PetscFunctionBegin;
 28:   PetscCall(MatShellGetContext(N, &A));
 29:   PetscCall(MatDestroy(&A));
 30:   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatTransposeGetMat_C", NULL));
 31:   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatProductSetFromOptions_anytype_C", NULL));
 32:   PetscCall(PetscObjectComposeFunction((PetscObject)N, "MatShellSetContext_C", NULL));
 33:   PetscFunctionReturn(PETSC_SUCCESS);
 34: }

 36: static PetscErrorCode MatDuplicate_Transpose(Mat N, MatDuplicateOption op, Mat *m)
 37: {
 38:   Mat A, C;

 40:   PetscFunctionBegin;
 41:   PetscCall(MatShellGetContext(N, &A));
 42:   PetscCall(MatDuplicate(A, op, &C));
 43:   PetscCall(MatCreateTranspose(C, m));
 44:   PetscCall(MatDestroy(&C));
 45:   if (op == MAT_COPY_VALUES) PetscCall(MatCopy(N, *m, SAME_NONZERO_PATTERN));
 46:   PetscFunctionReturn(PETSC_SUCCESS);
 47: }

 49: static PetscErrorCode MatHasOperation_Transpose(Mat mat, MatOperation op, PetscBool *has)
 50: {
 51:   Mat A;

 53:   PetscFunctionBegin;
 54:   PetscCall(MatShellGetContext(mat, &A));
 55:   *has = PETSC_FALSE;
 56:   if (op == MATOP_MULT || op == MATOP_MULT_ADD) {
 57:     PetscCall(MatHasOperation(A, MATOP_MULT_TRANSPOSE, has));
 58:   } else if (op == MATOP_MULT_TRANSPOSE || op == MATOP_MULT_TRANSPOSE_ADD) {
 59:     PetscCall(MatHasOperation(A, MATOP_MULT, has));
 60:   } else if (((void **)mat->ops)[op]) *has = PETSC_TRUE;
 61:   PetscFunctionReturn(PETSC_SUCCESS);
 62: }

 64: static PetscErrorCode MatProductSetFromOptions_Transpose(Mat D)
 65: {
 66:   Mat            A, B, C, Ain, Bin, Cin;
 67:   PetscBool      Aistrans, Bistrans, Cistrans;
 68:   PetscInt       Atrans, Btrans, Ctrans;
 69:   MatProductType ptype;

 71:   PetscFunctionBegin;
 72:   MatCheckProduct(D, 1);
 73:   A = D->product->A;
 74:   B = D->product->B;
 75:   C = D->product->C;
 76:   PetscCall(PetscObjectTypeCompare((PetscObject)A, MATTRANSPOSEVIRTUAL, &Aistrans));
 77:   PetscCall(PetscObjectTypeCompare((PetscObject)B, MATTRANSPOSEVIRTUAL, &Bistrans));
 78:   PetscCall(PetscObjectTypeCompare((PetscObject)C, MATTRANSPOSEVIRTUAL, &Cistrans));
 79:   PetscCheck(Aistrans || Bistrans || Cistrans, PetscObjectComm((PetscObject)D), PETSC_ERR_PLIB, "This should not happen");
 80:   Atrans = 0;
 81:   Ain    = A;
 82:   while (Aistrans) {
 83:     Atrans++;
 84:     PetscCall(MatTransposeGetMat(Ain, &Ain));
 85:     PetscCall(PetscObjectTypeCompare((PetscObject)Ain, MATTRANSPOSEVIRTUAL, &Aistrans));
 86:   }
 87:   Btrans = 0;
 88:   Bin    = B;
 89:   while (Bistrans) {
 90:     Btrans++;
 91:     PetscCall(MatTransposeGetMat(Bin, &Bin));
 92:     PetscCall(PetscObjectTypeCompare((PetscObject)Bin, MATTRANSPOSEVIRTUAL, &Bistrans));
 93:   }
 94:   Ctrans = 0;
 95:   Cin    = C;
 96:   while (Cistrans) {
 97:     Ctrans++;
 98:     PetscCall(MatTransposeGetMat(Cin, &Cin));
 99:     PetscCall(PetscObjectTypeCompare((PetscObject)Cin, MATTRANSPOSEVIRTUAL, &Cistrans));
100:   }
101:   Atrans = Atrans % 2;
102:   Btrans = Btrans % 2;
103:   Ctrans = Ctrans % 2;
104:   ptype  = D->product->type; /* same product type by default */
105:   if (Ain->symmetric == PETSC_BOOL3_TRUE) Atrans = 0;
106:   if (Bin->symmetric == PETSC_BOOL3_TRUE) Btrans = 0;
107:   if (Cin && Cin->symmetric == PETSC_BOOL3_TRUE) Ctrans = 0;

109:   if (Atrans || Btrans || Ctrans) {
110:     ptype = MATPRODUCT_UNSPECIFIED;
111:     switch (D->product->type) {
112:     case MATPRODUCT_AB:
113:       if (Atrans && Btrans) { /* At * Bt we do not have support for this */
114:         /* TODO custom implementation ? */
115:       } else if (Atrans) { /* At * B */
116:         ptype = MATPRODUCT_AtB;
117:       } else { /* A * Bt */
118:         ptype = MATPRODUCT_ABt;
119:       }
120:       break;
121:     case MATPRODUCT_AtB:
122:       if (Atrans && Btrans) { /* A * Bt */
123:         ptype = MATPRODUCT_ABt;
124:       } else if (Atrans) { /* A * B */
125:         ptype = MATPRODUCT_AB;
126:       } else { /* At * Bt we do not have support for this */
127:         /* TODO custom implementation ? */
128:       }
129:       break;
130:     case MATPRODUCT_ABt:
131:       if (Atrans && Btrans) { /* At * B */
132:         ptype = MATPRODUCT_AtB;
133:       } else if (Atrans) { /* At * Bt we do not have support for this */
134:         /* TODO custom implementation ? */
135:       } else { /* A * B */
136:         ptype = MATPRODUCT_AB;
137:       }
138:       break;
139:     case MATPRODUCT_PtAP:
140:       if (Atrans) { /* PtAtP */
141:         /* TODO custom implementation ? */
142:       } else { /* RARt */
143:         ptype = MATPRODUCT_RARt;
144:       }
145:       break;
146:     case MATPRODUCT_RARt:
147:       if (Atrans) { /* RAtRt */
148:         /* TODO custom implementation ? */
149:       } else { /* PtAP */
150:         ptype = MATPRODUCT_PtAP;
151:       }
152:       break;
153:     case MATPRODUCT_ABC:
154:       /* TODO custom implementation ? */
155:       break;
156:     default:
157:       SETERRQ(PetscObjectComm((PetscObject)D), PETSC_ERR_SUP, "ProductType %s is not supported", MatProductTypes[D->product->type]);
158:     }
159:   }
160:   PetscCall(MatProductReplaceMats(Ain, Bin, Cin, D));
161:   PetscCall(MatProductSetType(D, ptype));
162:   PetscCall(MatProductSetFromOptions(D));
163:   PetscFunctionReturn(PETSC_SUCCESS);
164: }

166: static PetscErrorCode MatGetDiagonal_Transpose(Mat N, Vec v)
167: {
168:   Mat A;

170:   PetscFunctionBegin;
171:   PetscCall(MatShellGetContext(N, &A));
172:   PetscCall(MatGetDiagonal(A, v));
173:   PetscFunctionReturn(PETSC_SUCCESS);
174: }

176: static PetscErrorCode MatCopy_Transpose(Mat A, Mat B, MatStructure str)
177: {
178:   Mat a, b;

180:   PetscFunctionBegin;
181:   PetscCall(MatShellGetContext(A, &a));
182:   PetscCall(MatShellGetContext(B, &b));
183:   PetscCall(MatCopy(a, b, str));
184:   PetscFunctionReturn(PETSC_SUCCESS);
185: }

187: static PetscErrorCode MatConvert_Transpose(Mat N, MatType newtype, MatReuse reuse, Mat *newmat)
188: {
189:   Mat         A;
190:   PetscScalar vscale = 1.0, vshift = 0.0;
191:   PetscBool   flg;

193:   PetscFunctionBegin;
194:   PetscCall(MatShellGetContext(N, &A));
195:   PetscCall(MatHasOperation(A, MATOP_TRANSPOSE, &flg));
196:   if (flg || N->ops->getrow) { /* if this condition is false, MatConvert_Shell() will be called in MatConvert_Basic(), so the following checks are not needed */
197:     PetscCall(MatShellGetScalingShifts(N, &vshift, &vscale, (Vec *)MAT_SHELL_NOT_ALLOWED, (Vec *)MAT_SHELL_NOT_ALLOWED, (Vec *)MAT_SHELL_NOT_ALLOWED, (Mat *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED, (IS *)MAT_SHELL_NOT_ALLOWED));
198:   }
199:   if (flg) {
200:     Mat B;

202:     PetscCall(MatTranspose(A, MAT_INITIAL_MATRIX, &B));
203:     if (reuse != MAT_INPLACE_MATRIX) {
204:       PetscCall(MatConvert(B, newtype, reuse, newmat));
205:       PetscCall(MatDestroy(&B));
206:     } else {
207:       PetscCall(MatConvert(B, newtype, MAT_INPLACE_MATRIX, &B));
208:       PetscCall(MatHeaderReplace(N, &B));
209:     }
210:   } else { /* use basic converter as fallback */
211:     flg = (PetscBool)(N->ops->getrow != NULL);
212:     PetscCall(MatConvert_Basic(N, newtype, reuse, newmat));
213:   }
214:   if (flg) {
215:     PetscCall(MatScale(*newmat, vscale));
216:     PetscCall(MatShift(*newmat, vshift));
217:   }
218:   PetscFunctionReturn(PETSC_SUCCESS);
219: }

221: static PetscErrorCode MatTransposeGetMat_Transpose(Mat N, Mat *M)
222: {
223:   PetscFunctionBegin;
224:   PetscCall(MatShellGetContext(N, M));
225:   PetscFunctionReturn(PETSC_SUCCESS);
226: }

228: /*@
229:   MatTransposeGetMat - Gets the `Mat` object stored inside a `MATTRANSPOSEVIRTUAL`

231:   Logically Collective

233:   Input Parameter:
234: . A - the `MATTRANSPOSEVIRTUAL` matrix

236:   Output Parameter:
237: . M - the matrix object stored inside `A`

239:   Level: intermediate

241: .seealso: [](ch_matrices), `Mat`, `MATTRANSPOSEVIRTUAL`, `MatCreateTranspose()`
242: @*/
243: PetscErrorCode MatTransposeGetMat(Mat A, Mat *M)
244: {
245:   PetscFunctionBegin;
248:   PetscAssertPointer(M, 2);
249:   PetscUseMethod(A, "MatTransposeGetMat_C", (Mat, Mat *), (A, M));
250:   PetscFunctionReturn(PETSC_SUCCESS);
251: }

253: /*MC
254:    MATTRANSPOSEVIRTUAL - "transpose" - A matrix type that represents a virtual transpose of a matrix

256:   Level: advanced

258:   Developer Notes:
259:   This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code

261:   Users cannot call `MatShellSetOperation()` on this class; some error checking is in place to detect such incorrect usage.

263: .seealso: [](ch_matrices), `Mat`, `MATHERMITIANTRANSPOSEVIRTUAL`, `MatCreateHermitianTranspose()`, `MatCreateTranspose()`,
264:           `MATNORMALHERMITIAN`, `MATNORMAL`
265: M*/

267: /*@
268:   MatCreateTranspose - Creates a new matrix `MATTRANSPOSEVIRTUAL` object that behaves like A'

270:   Collective

272:   Input Parameter:
273: . A - the (possibly rectangular) matrix

275:   Output Parameter:
276: . N - the matrix that represents A'

278:   Level: intermediate

280:   Note:
281:   The transpose A' is NOT actually formed! Rather the new matrix
282:   object performs the matrix-vector product by using the `MatMultTranspose()` on
283:   the original matrix

285: .seealso: [](ch_matrices), `Mat`, `MATTRANSPOSEVIRTUAL`, `MatCreateNormal()`, `MatMult()`, `MatMultTranspose()`, `MatCreate()`,
286:           `MATNORMALHERMITIAN`
287: @*/
288: PetscErrorCode MatCreateTranspose(Mat A, Mat *N)
289: {
290:   VecType vtype;

292:   PetscFunctionBegin;
293:   PetscCall(MatCreate(PetscObjectComm((PetscObject)A), N));
294:   PetscCall(PetscLayoutReference(A->rmap, &((*N)->cmap)));
295:   PetscCall(PetscLayoutReference(A->cmap, &((*N)->rmap)));
296:   PetscCall(MatSetType(*N, MATSHELL));
297:   PetscCall(MatShellSetContext(*N, A));
298:   PetscCall(PetscObjectReference((PetscObject)A));

300:   PetscCall(MatSetBlockSizes(*N, PetscAbs(A->cmap->bs), PetscAbs(A->rmap->bs)));
301:   PetscCall(MatGetVecType(A, &vtype));
302:   PetscCall(MatSetVecType(*N, vtype));
303: #if defined(PETSC_HAVE_DEVICE)
304:   PetscCall(MatBindToCPU(*N, A->boundtocpu));
305: #endif
306:   PetscCall(MatSetUp(*N));

308:   PetscCall(MatShellSetOperation(*N, MATOP_DESTROY, (void (*)(void))MatDestroy_Transpose));
309:   PetscCall(MatShellSetOperation(*N, MATOP_MULT, (void (*)(void))MatMult_Transpose));
310:   PetscCall(MatShellSetOperation(*N, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Transpose));
311:   PetscCall(MatShellSetOperation(*N, MATOP_DUPLICATE, (void (*)(void))MatDuplicate_Transpose));
312:   PetscCall(MatShellSetOperation(*N, MATOP_HAS_OPERATION, (void (*)(void))MatHasOperation_Transpose));
313:   PetscCall(MatShellSetOperation(*N, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Transpose));
314:   PetscCall(MatShellSetOperation(*N, MATOP_COPY, (void (*)(void))MatCopy_Transpose));
315:   PetscCall(MatShellSetOperation(*N, MATOP_CONVERT, (void (*)(void))MatConvert_Transpose));

317:   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatTransposeGetMat_C", MatTransposeGetMat_Transpose));
318:   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatProductSetFromOptions_anytype_C", MatProductSetFromOptions_Transpose));
319:   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetContext_C", MatShellSetContext_Immutable));
320:   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable));
321:   PetscCall(PetscObjectComposeFunction((PetscObject)*N, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable));
322:   PetscCall(PetscObjectChangeTypeName((PetscObject)*N, MATTRANSPOSEVIRTUAL));
323:   PetscFunctionReturn(PETSC_SUCCESS);
324: }