Actual source code: mpimattransposematmult.c
petsc-3.14.6 2021-03-30
2: /*
3: Defines matrix-matrix product routines for an MPIAIJ matrix A and an MPIDense matrix B
4: C = A^T * B
5: The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
6: */
7: #include <../src/mat/impls/aij/seq/aij.h>
8: #include <../src/mat/impls/aij/mpi/mpiaij.h>
9: #include <../src/mat/impls/dense/mpi/mpidense.h>
11: PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data)
12: {
13: PetscErrorCode ierr;
14: Mat_MatTransMatMult *atb = (Mat_MatTransMatMult*)data;
17: MatDestroy(&atb->mA);
18: VecDestroy(&atb->bt);
19: VecDestroy(&atb->ct);
20: PetscFree(atb);
21: return(0);
22: }
24: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat,Mat,Mat);
26: PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A,Mat B,PetscReal fill,Mat C)
27: {
28: PetscErrorCode ierr;
29: Mat_MatTransMatMult *atb;
30: PetscBool cisdense;
33: MatCheckProduct(C,4);
34: if (C->product->data) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Extra product struct not empty");
36: /* create output dense matrix C = A^T*B */
37: MatSetSizes(C,A->cmap->n,B->cmap->n,A->cmap->N,B->cmap->N);
38: PetscObjectTypeCompareAny((PetscObject)C,&cisdense,MATMPIDENSE,MATMPIDENSECUDA,"");
39: if (!cisdense) {
40: MatSetType(C,((PetscObject)B)->type_name);
41: }
42: MatSetUp(C);
44: /* create additional data structure for the product */
45: PetscNew(&atb);
46: if (B->cmap->N) {
47: MatCreateMAIJ(A,B->cmap->N,&atb->mA);
48: if (!atb->mA->assembled) {
49: MatAssemblyBegin(atb->mA,MAT_FINAL_ASSEMBLY);
50: MatAssemblyEnd(atb->mA,MAT_FINAL_ASSEMBLY);
51: }
52: MatCreateVecs(atb->mA,&atb->ct,&atb->bt);
53: }
54: C->product->data = atb;
55: C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;
57: C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
58: return(0);
59: }
61: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A,Mat B,Mat C)
62: {
63: PetscErrorCode ierr;
64: const PetscScalar *Barray,*ctarray;
65: PetscScalar *Carray,*btarray;
66: Mat_MPIDense *b=(Mat_MPIDense*)B->data,*c=(Mat_MPIDense*)C->data;
67: Mat_SeqDense *bseq=(Mat_SeqDense*)(b->A)->data,*cseq=(Mat_SeqDense*)(c->A)->data;
68: PetscInt i,j,m=A->rmap->n,n=A->cmap->n,ldb=bseq->lda,BN=B->cmap->N,ldc=cseq->lda;
69: Mat_MatTransMatMult *atb;
70: Vec bt,ct;
73: MatCheckProduct(C,3);
74: atb=(Mat_MatTransMatMult *)C->product->data;
75: if (!atb) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_PLIB,"Missing product struct");
76: if (!BN) {
77: MatAssemblyBegin(C,MAT_FINAL_ASSEMBLY);
78: MatAssemblyEnd(C,MAT_FINAL_ASSEMBLY);
79: return(0);
80: }
81: bt = atb->bt;
82: ct = atb->ct;
83: /* transpose local arry of B, then copy it to vector bt */
84: MatDenseGetArrayRead(B,&Barray);
85: VecGetArray(bt,&btarray);
87: for (j=0; j<BN; j++) {
88: for (i=0; i<m; i++) btarray[i*BN + j] = Barray[j*ldb + i];
89: }
90: VecRestoreArray(bt,&btarray);
91: MatDenseRestoreArrayRead(B,&Barray);
93: /* compute ct = mA^T * cb */
94: MatMultTranspose(atb->mA,bt,ct);
96: /* transpose local array of ct to matrix C */
97: MatDenseGetArray(C,&Carray);
98: VecGetArrayRead(ct,&ctarray);
100: for (j=0; j<BN; j++) {
101: for (i=0; i<n; i++) Carray[j*ldc + i] = ctarray[i*BN + j];
102: }
103: VecRestoreArrayRead(ct,&ctarray);
104: MatDenseRestoreArray(C,&Carray);
105: MatAssemblyBegin(C,MAT_FINAL_ASSEMBLY);
106: MatAssemblyEnd(C,MAT_FINAL_ASSEMBLY);
107: return(0);
108: }