Actual source code: mpimatmatmatmult.c

petsc-3.13.6 2020-09-29
Report Typos and Errors
  1: /*
  2:   Defines matrix-matrix-matrix product routines for MPIAIJ matrices
  3:           D = A * B * C
  4: */
  5:  #include <../src/mat/impls/aij/mpi/mpiaij.h>

  7: #if defined(PETSC_HAVE_HYPRE)
  8: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,PetscReal,Mat);
  9: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,Mat);

 11: PETSC_INTERN PetscErrorCode MatProductNumeric_ABC_Transpose_AIJ_AIJ(Mat RAP)
 12: {
 14:   Mat_Product    *product = RAP->product;
 15:   Mat            Rt,R=product->A,A=product->B,P=product->C;

 18:   MatTransposeGetMat(R,&Rt);
 19:   MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,RAP);
 20:   return(0);
 21: }

 23: PETSC_INTERN PetscErrorCode MatProductSymbolic_ABC_Transpose_AIJ_AIJ(Mat RAP)
 24: {
 26:   Mat_Product    *product = RAP->product;
 27:   Mat            Rt,R=product->A,A=product->B,P=product->C;
 28:   PetscBool      flg;

 31:   /* local sizes of matrices will be checked by the calling subroutines */
 32:   MatTransposeGetMat(R,&Rt);
 33:   PetscObjectTypeCompareAny((PetscObject)Rt,&flg,MATSEQAIJ,MATSEQAIJMKL,MATMPIAIJ,NULL);
 34:   if (!flg) SETERRQ1(PetscObjectComm((PetscObject)Rt),PETSC_ERR_SUP,"Not for matrix type %s",((PetscObject)Rt)->type_name);
 35:   MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,product->fill,RAP);
 36:   RAP->ops->productnumeric = MatProductNumeric_ABC_Transpose_AIJ_AIJ;
 37:   return(0);
 38: }

 40: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose_AIJ_AIJ(Mat C)
 41: {
 42:   Mat_Product *product = C->product;

 45:   if (product->type == MATPRODUCT_ABC) {
 46:     C->ops->productsymbolic = MatProductSymbolic_ABC_Transpose_AIJ_AIJ;
 47:   } else SETERRQ1(PetscObjectComm((PetscObject)C),PETSC_ERR_SUP,"MatProduct type %s is not supported for Transpose, AIJ and AIJ matrices",MatProductTypes[product->type]);
 48:   return(0);
 49: }
 50: #endif

 52: PetscErrorCode MatFreeIntermediateDataStructures_MPIAIJ_BC(Mat ABC)
 53: {
 54:   Mat_MPIAIJ        *a = (Mat_MPIAIJ*)ABC->data;
 55:   Mat_MatMatMatMult *matmatmatmult = a->matmatmatmult;
 56:   PetscErrorCode    ierr;

 59:   if (!matmatmatmult) return(0);

 61:   MatDestroy(&matmatmatmult->BC);
 62:   ABC->ops->destroy = matmatmatmult->destroy;
 63:   PetscFree(a->matmatmatmult);
 64:   return(0);
 65: }

 67: PetscErrorCode MatDestroy_MPIAIJ_MatMatMatMult(Mat A)
 68: {
 69:   PetscErrorCode    ierr;

 72:   (*A->ops->freeintermediatedatastructures)(A);
 73:   (*A->ops->destroy)(A);
 74:   return(0);
 75: }

 77: PetscErrorCode MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,PetscReal fill,Mat D)
 78: {
 80:   Mat            BC;
 81:   PetscBool      scalable;
 82:   Mat_Product    *product = D->product;

 85:   MatCreate(PetscObjectComm((PetscObject)A),&BC);
 86:   if (product) {
 87:     PetscStrcmp(product->alg,"scalable",&scalable);
 88:   } else SETERRQ(PetscObjectComm((PetscObject)D),PETSC_ERR_ARG_NULL,"Call MatProductCreate() first");

 90:   if (scalable) {
 91:     MatMatMultSymbolic_MPIAIJ_MPIAIJ(B,C,fill,BC);
 92:     MatZeroEntries(BC); /* initialize value entries of BC */
 93:     MatMatMultSymbolic_MPIAIJ_MPIAIJ(A,BC,fill,D);
 94:   } else {
 95:     MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(B,C,fill,BC);
 96:     MatZeroEntries(BC); /* initialize value entries of BC */
 97:     MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(A,BC,fill,D);
 98:   }
 99:   product->Dwork = BC;

101:   D->ops->matmatmultnumeric = MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ;
102:   D->ops->freeintermediatedatastructures = MatFreeIntermediateDataStructures_MPIAIJ_BC;
103:   return(0);
104: }

106: PetscErrorCode MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,Mat D)
107: {
109:   Mat_Product    *product = D->product;
110:   Mat            BC = product->Dwork;

113:   (BC->ops->matmultnumeric)(B,C,BC);
114:   (D->ops->matmultnumeric)(A,BC,D);
115:   return(0);
116: }

118: /* ----------------------------------------------------- */
119: PetscErrorCode MatDestroy_MPIAIJ_RARt(Mat C)
120: {
122:   Mat_MPIAIJ     *c    = (Mat_MPIAIJ*)C->data;
123:   Mat_RARt       *rart = c->rart;

126:   MatDestroy(&rart->Rt);

128:   C->ops->destroy = rart->destroy;
129:   if (C->ops->destroy) {
130:     (*C->ops->destroy)(C);
131:   }
132:   PetscFree(rart);
133:   return(0);
134: }

136: PetscErrorCode MatProductNumeric_RARt_MPIAIJ_MPIAIJ(Mat C)
137: {
139:   Mat_MPIAIJ     *c = (Mat_MPIAIJ*)C->data;
140:   Mat_RARt       *rart = c->rart;
141:   Mat_Product    *product = C->product;
142:   Mat            A=product->A,R=product->B,Rt=rart->Rt;

145:   MatTranspose(R,MAT_REUSE_MATRIX,&Rt);
146:   (C->ops->matmatmultnumeric)(R,A,Rt,C);
147:   return(0);
148: }

150: PetscErrorCode MatProductSymbolic_RARt_MPIAIJ_MPIAIJ(Mat C)
151: {
152:   PetscErrorCode      ierr;
153:   Mat_Product         *product = C->product;
154:   Mat                 A=product->A,R=product->B,Rt;
155:   PetscReal           fill=product->fill;
156:   Mat_RARt            *rart;
157:   Mat_MPIAIJ          *c;

160:   MatTranspose(R,MAT_INITIAL_MATRIX,&Rt);
161:   /* product->Dwork is used to store A*Rt in MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ() */
162:   MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(R,A,Rt,fill,C);
163:   C->ops->productnumeric = MatProductNumeric_RARt_MPIAIJ_MPIAIJ;

165:   /* create a supporting struct */
166:   PetscNew(&rart);
167:   c        = (Mat_MPIAIJ*)C->data;
168:   c->rart  = rart;
169:   rart->Rt = Rt;
170:   rart->destroy   = C->ops->destroy;
171:   C->ops->destroy = MatDestroy_MPIAIJ_RARt;
172:   return(0);
173: }