Actual source code: lrc.c


  2: #include <petsc/private/matimpl.h>

  4: PETSC_EXTERN PetscErrorCode VecGetRootType_Private(Vec,VecType*);

  6: typedef struct {
  7:   Mat A;           /* sparse matrix */
  8:   Mat U,V;         /* dense tall-skinny matrices */
  9:   Vec c;           /* sequential vector containing the diagonal of C */
 10:   Vec work1,work2; /* sequential vectors that hold partial products */
 11:   Vec xl,yl;       /* auxiliary sequential vectors for matmult operation */
 12: } Mat_LRC;

 14: static PetscErrorCode MatMult_LRC_kernel(Mat N,Vec x,Vec y,PetscBool transpose)
 15: {
 16:   Mat_LRC        *Na = (Mat_LRC*)N->data;
 17:   PetscMPIInt    size;
 18:   Mat            U,V;

 20:   U = transpose ? Na->V : Na->U;
 21:   V = transpose ? Na->U : Na->V;
 22:   MPI_Comm_size(PetscObjectComm((PetscObject)N),&size);
 23:   if (size == 1) {
 24:     MatMultHermitianTranspose(V,x,Na->work1);
 25:     if (Na->c) {
 26:       VecPointwiseMult(Na->work1,Na->c,Na->work1);
 27:     }
 28:     if (Na->A) {
 29:       if (transpose) {
 30:         MatMultTranspose(Na->A,x,y);
 31:       } else {
 32:         MatMult(Na->A,x,y);
 33:       }
 34:       MatMultAdd(U,Na->work1,y,y);
 35:     } else {
 36:       MatMult(U,Na->work1,y);
 37:     }
 38:   } else {
 39:     Mat               Uloc,Vloc;
 40:     Vec               yl,xl;
 41:     const PetscScalar *w1;
 42:     PetscScalar       *w2;
 43:     PetscInt          nwork;
 44:     PetscMPIInt       mpinwork;

 46:     xl = transpose ? Na->yl : Na->xl;
 47:     yl = transpose ? Na->xl : Na->yl;
 48:     VecGetLocalVector(y,yl);
 49:     MatDenseGetLocalMatrix(U,&Uloc);
 50:     MatDenseGetLocalMatrix(V,&Vloc);

 52:     /* multiply the local part of V with the local part of x */
 53:     VecGetLocalVectorRead(x,xl);
 54:     MatMultHermitianTranspose(Vloc,xl,Na->work1);
 55:     VecRestoreLocalVectorRead(x,xl);

 57:     /* form the sum of all the local multiplies: this is work2 = V'*x =
 58:        sum_{all processors} work1 */
 59:     VecGetArrayRead(Na->work1,&w1);
 60:     VecGetArrayWrite(Na->work2,&w2);
 61:     VecGetLocalSize(Na->work1,&nwork);
 62:     PetscMPIIntCast(nwork,&mpinwork);
 63:     MPIU_Allreduce(w1,w2,mpinwork,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)N));
 64:     VecRestoreArrayRead(Na->work1,&w1);
 65:     VecRestoreArrayWrite(Na->work2,&w2);

 67:     if (Na->c) {  /* work2 = C*work2 */
 68:       VecPointwiseMult(Na->work2,Na->c,Na->work2);
 69:     }

 71:     if (Na->A) {
 72:       /* form y = A*x or A^t*x */
 73:       if (transpose) {
 74:         MatMultTranspose(Na->A,x,y);
 75:       } else {
 76:         MatMult(Na->A,x,y);
 77:       }
 78:       /* multiply-add y = y + U*work2 */
 79:       MatMultAdd(Uloc,Na->work2,yl,yl);
 80:     } else {
 81:       /* multiply y = U*work2 */
 82:       MatMult(Uloc,Na->work2,yl);
 83:     }

 85:     VecRestoreLocalVector(y,yl);
 86:   }
 87:   return 0;
 88: }

 90: static PetscErrorCode MatMult_LRC(Mat N,Vec x,Vec y)
 91: {
 92:   MatMult_LRC_kernel(N,x,y,PETSC_FALSE);
 93:   return 0;
 94: }

 96: static PetscErrorCode MatMultTranspose_LRC(Mat N,Vec x,Vec y)
 97: {
 98:   MatMult_LRC_kernel(N,x,y,PETSC_TRUE);
 99:   return 0;
100: }

102: static PetscErrorCode MatDestroy_LRC(Mat N)
103: {
104:   Mat_LRC        *Na = (Mat_LRC*)N->data;

106:   MatDestroy(&Na->A);
107:   MatDestroy(&Na->U);
108:   MatDestroy(&Na->V);
109:   VecDestroy(&Na->c);
110:   VecDestroy(&Na->work1);
111:   VecDestroy(&Na->work2);
112:   VecDestroy(&Na->xl);
113:   VecDestroy(&Na->yl);
114:   PetscFree(N->data);
115:   PetscObjectComposeFunction((PetscObject)N,"MatLRCGetMats_C",NULL);
116:   return 0;
117: }

119: static PetscErrorCode MatLRCGetMats_LRC(Mat N,Mat *A,Mat *U,Vec *c,Mat *V)
120: {
121:   Mat_LRC *Na = (Mat_LRC*)N->data;

123:   if (A) *A = Na->A;
124:   if (U) *U = Na->U;
125:   if (c) *c = Na->c;
126:   if (V) *V = Na->V;
127:   return 0;
128: }

130: /*@
131:    MatLRCGetMats - Returns the constituents of an LRC matrix

133:    Collective on Mat

135:    Input Parameter:
136: .  N - matrix of type LRC

138:    Output Parameters:
139: +  A - the (sparse) matrix
140: .  U - first dense rectangular (tall and skinny) matrix
141: .  c - a sequential vector containing the diagonal of C
142: -  V - second dense rectangular (tall and skinny) matrix

144:    Note:
145:    The returned matrices need not be destroyed by the caller.

147:    Level: intermediate

149: .seealso: MatCreateLRC()
150: @*/
151: PetscErrorCode MatLRCGetMats(Mat N,Mat *A,Mat *U,Vec *c,Mat *V)
152: {
153:   PetscUseMethod(N,"MatLRCGetMats_C",(Mat,Mat*,Mat*,Vec*,Mat*),(N,A,U,c,V));
154:   return 0;
155: }

157: /*@
158:    MatCreateLRC - Creates a new matrix object that behaves like A + U*C*V'

160:    Collective on Mat

162:    Input Parameters:
163: +  A    - the (sparse) matrix (can be NULL)
164: .  U, V - two dense rectangular (tall and skinny) matrices
165: -  c    - a vector containing the diagonal of C (can be NULL)

167:    Output Parameter:
168: .  N    - the matrix that represents A + U*C*V'

170:    Notes:
171:    The matrix A + U*C*V' is not formed! Rather the new matrix
172:    object performs the matrix-vector product by first multiplying by
173:    A and then adding the other term.

175:    C is a diagonal matrix (represented as a vector) of order k,
176:    where k is the number of columns of both U and V.

178:    If A is NULL then the new object behaves like a low-rank matrix U*C*V'.

180:    Use V=U (or V=NULL) for a symmetric low-rank correction, A + U*C*U'.

182:    If c is NULL then the low-rank correction is just U*V'.
183:    If a sequential c vector is used for a parallel matrix,
184:    PETSc assumes that the values of the vector are consistently set across processors.

186:    Level: intermediate

188: .seealso: MatLRCGetMats()
189: @*/
190: PetscErrorCode MatCreateLRC(Mat A,Mat U,Vec c,Mat V,Mat *N)
191: {
192:   PetscBool      match;
193:   PetscInt       m,n,k,m1,n1,k1;
194:   Mat_LRC        *Na;
195:   Mat            Uloc;
196:   PetscMPIInt    size, csize = 0;

201:   if (V) {
204:   }

207:   if (!V) V = U;
208:   PetscObjectBaseTypeCompareAny((PetscObject)U,&match,MATSEQDENSE,MATMPIDENSE,"");
210:   PetscObjectBaseTypeCompareAny((PetscObject)V,&match,MATSEQDENSE,MATMPIDENSE,"");
212:   PetscStrcmp(U->defaultvectype,V->defaultvectype,&match);
214:   if (A) {
215:     PetscStrcmp(A->defaultvectype,U->defaultvectype,&match);
217:   }

219:   MPI_Comm_size(PetscObjectComm((PetscObject)U),&size);
220:   MatGetSize(U,NULL,&k);
221:   MatGetSize(V,NULL,&k1);
223:   MatGetLocalSize(U,&m,NULL);
224:   MatGetLocalSize(V,&n,NULL);
225:   if (A) {
226:     MatGetLocalSize(A,&m1,&n1);
229:   }
230:   if (c) {
231:     MPI_Comm_size(PetscObjectComm((PetscObject)c),&csize);
232:     VecGetSize(c,&k1);
235:   }

237:   MatCreate(PetscObjectComm((PetscObject)U),N);
238:   MatSetSizes(*N,m,n,PETSC_DECIDE,PETSC_DECIDE);
239:   MatSetVecType(*N,U->defaultvectype);
240:   PetscObjectChangeTypeName((PetscObject)*N,MATLRC);
241:   /* Flag matrix as symmetric if A is symmetric and U == V */
242:   MatSetOption(*N,MAT_SYMMETRIC,(PetscBool)((A ? A->symmetric : PETSC_TRUE) && U == V));

244:   PetscNewLog(*N,&Na);
245:   (*N)->data = (void*)Na;
246:   Na->A      = A;
247:   Na->U      = U;
248:   Na->c      = c;
249:   Na->V      = V;

251:   PetscObjectReference((PetscObject)A);
252:   PetscObjectReference((PetscObject)Na->U);
253:   PetscObjectReference((PetscObject)Na->V);
254:   PetscObjectReference((PetscObject)c);

256:   MatDenseGetLocalMatrix(Na->U,&Uloc);
257:   MatCreateVecs(Uloc,&Na->work1,NULL);
258:   if (size != 1) {
259:     Mat Vloc;

261:     if (Na->c && csize != 1) { /* scatter parallel vector to sequential */
262:       VecScatter sct;

264:       VecScatterCreateToAll(Na->c,&sct,&c);
265:       VecScatterBegin(sct,Na->c,c,INSERT_VALUES,SCATTER_FORWARD);
266:       VecScatterEnd(sct,Na->c,c,INSERT_VALUES,SCATTER_FORWARD);
267:       VecScatterDestroy(&sct);
268:       VecDestroy(&Na->c);
269:       PetscLogObjectParent((PetscObject)*N,(PetscObject)c);
270:       Na->c = c;
271:     }
272:     MatDenseGetLocalMatrix(Na->V,&Vloc);
273:     VecDuplicate(Na->work1,&Na->work2);
274:     MatCreateVecs(Vloc,NULL,&Na->xl);
275:     MatCreateVecs(Uloc,NULL,&Na->yl);
276:   }
277:   PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->work1);
278:   PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->work1);
279:   PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->xl);
280:   PetscLogObjectParent((PetscObject)*N,(PetscObject)Na->yl);

282:   /* Internally create a scaling vector if roottypes do not match */
283:   if (Na->c) {
284:     VecType rt1,rt2;

286:     VecGetRootType_Private(Na->work1,&rt1);
287:     VecGetRootType_Private(Na->c,&rt2);
288:     PetscStrcmp(rt1,rt2,&match);
289:     if (!match) {
290:       VecDuplicate(Na->c,&c);
291:       VecCopy(Na->c,c);
292:       VecDestroy(&Na->c);
293:       PetscLogObjectParent((PetscObject)*N,(PetscObject)c);
294:       Na->c = c;
295:     }
296:   }

298:   (*N)->ops->destroy       = MatDestroy_LRC;
299:   (*N)->ops->mult          = MatMult_LRC;
300:   (*N)->ops->multtranspose = MatMultTranspose_LRC;

302:   (*N)->assembled    = PETSC_TRUE;
303:   (*N)->preallocated = PETSC_TRUE;

305:   PetscObjectComposeFunction((PetscObject)(*N),"MatLRCGetMats_C",MatLRCGetMats_LRC);
306:   MatSetUp(*N);
307:   return 0;
308: }