Actual source code: lmvmimpl.c

  1: #include <petscdevice.h>
  2: #include <../src/ksp/ksp/utils/lmvm/lmvm.h>
  3: #include <petsc/private/deviceimpl.h>
  4: #include "blas_cyclic/blas_cyclic.h"
  5: #include "rescale/symbrdnrescale.h"

  7: PetscLogEvent MATLMVM_Update;

  9: static PetscBool MatLMVMPackageInitialized = PETSC_FALSE;

 11: static PetscErrorCode MatLMVMPackageInitialize(void)
 12: {
 13:   PetscFunctionBegin;
 14:   if (MatLMVMPackageInitialized) PetscFunctionReturn(PETSC_SUCCESS);
 15:   MatLMVMPackageInitialized = PETSC_TRUE;
 16:   PetscCall(PetscLogEventRegister("AXPBYCyclic", MAT_CLASSID, &AXPBY_Cyc));
 17:   PetscCall(PetscLogEventRegister("DMVCyclic", MAT_CLASSID, &DMV_Cyc));
 18:   PetscCall(PetscLogEventRegister("DSVCyclic", MAT_CLASSID, &DSV_Cyc));
 19:   PetscCall(PetscLogEventRegister("TRSVCyclic", MAT_CLASSID, &TRSV_Cyc));
 20:   PetscCall(PetscLogEventRegister("GEMVCyclic", MAT_CLASSID, &GEMV_Cyc));
 21:   PetscCall(PetscLogEventRegister("HEMVCyclic", MAT_CLASSID, &HEMV_Cyc));
 22:   PetscCall(PetscLogEventRegister("LMBasisGEMM", MAT_CLASSID, &LMBASIS_GEMM));
 23:   PetscCall(PetscLogEventRegister("LMBasisGEMV", MAT_CLASSID, &LMBASIS_GEMV));
 24:   PetscCall(PetscLogEventRegister("LMBasisGEMVH", MAT_CLASSID, &LMBASIS_GEMVH));
 25:   PetscCall(PetscLogEventRegister("LMProdsMult", MAT_CLASSID, &LMPROD_Mult));
 26:   PetscCall(PetscLogEventRegister("LMProdsSolve", MAT_CLASSID, &LMPROD_Solve));
 27:   PetscCall(PetscLogEventRegister("LMProdsUpdate", MAT_CLASSID, &LMPROD_Update));
 28:   PetscCall(PetscLogEventRegister("MatLMVMUpdate", MAT_CLASSID, &MATLMVM_Update));
 29:   PetscCall(PetscLogEventRegister("SymBrdnRescale", MAT_CLASSID, &SBRDN_Rescale));
 30:   PetscFunctionReturn(PETSC_SUCCESS);
 31: }

 33: const char *const MatLMVMMultAlgorithms[] = {
 34:   "recursive", "dense", "compact_dense", "MatLMVMMatvecTypes", "MATLMVM_MATVEC_", NULL,
 35: };

 37: PetscBool  ByrdNocedalSchnabelCite       = PETSC_FALSE;
 38: const char ByrdNocedalSchnabelCitation[] = "@article{Byrd1994,"
 39:                                            "  title = {Representations of quasi-Newton matrices and their use in limited memory methods},"
 40:                                            "  volume = {63},"
 41:                                            "  ISSN = {1436-4646},"
 42:                                            "  url = {http://dx.doi.org/10.1007/BF01582063},"
 43:                                            "  DOI = {10.1007/bf01582063},"
 44:                                            "  number = {1-3},"
 45:                                            "  journal = {Mathematical Programming},"
 46:                                            "  publisher = {Springer Science and Business Media LLC},"
 47:                                            "  author = {Byrd,  Richard H. and Nocedal,  Jorge and Schnabel,  Robert B.},"
 48:                                            "  year = {1994},"
 49:                                            "  month = jan,"
 50:                                            "  pages = {129-156}"
 51:                                            "}\n";

 53: PETSC_INTERN PetscErrorCode MatReset_LMVM(Mat B, MatLMVMResetMode mode)
 54: {
 55:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

 57:   PetscFunctionBegin;
 58:   lmvm->k        = 0;
 59:   lmvm->prev_set = PETSC_FALSE;
 60:   lmvm->shift    = 0.0;
 61:   if (MatLMVMResetClearsBases(mode)) {
 62:     for (PetscInt i = 0; i < LMBASIS_END; i++) PetscCall(LMBasisDestroy(&lmvm->basis[i]));
 63:     for (PetscInt k = 0; k < LMBLOCK_END; k++) {
 64:       for (PetscInt i = 0; i < LMBASIS_END; i++) {
 65:         for (PetscInt j = 0; j < LMBASIS_END; j++) PetscCall(LMProductsDestroy(&lmvm->products[k][i][j]));
 66:       }
 67:     }
 68:     B->preallocated = PETSC_FALSE; // MatSetUp() needs to be run to create at least the S and Y bases
 69:   } else {
 70:     for (PetscInt i = 0; i < LMBASIS_END; i++) PetscCall(LMBasisReset(lmvm->basis[i]));
 71:     for (PetscInt k = 0; k < LMBLOCK_END; k++) {
 72:       for (PetscInt i = 0; i < LMBASIS_END; i++) {
 73:         for (PetscInt j = 0; j < LMBASIS_END; j++) PetscCall(LMProductsReset(lmvm->products[k][i][j]));
 74:       }
 75:     }
 76:   }
 77:   if (MatLMVMResetClearsJ0(mode)) PetscCall(MatLMVMClearJ0(B));
 78:   if (MatLMVMResetClearsVecs(mode)) {
 79:     PetscCall(VecDestroy(&lmvm->Xprev));
 80:     PetscCall(VecDestroy(&lmvm->Fprev));
 81:     B->preallocated = PETSC_FALSE; // MatSetUp() needs to be run to create these vecs
 82:   }
 83:   if (MatLMVMResetClearsAll(mode)) {
 84:     lmvm->nupdates = 0;
 85:     lmvm->nrejects = 0;
 86:   }
 87:   PetscFunctionReturn(PETSC_SUCCESS);
 88: }

 90: PETSC_INTERN PetscErrorCode MatLMVMAllocateBases(Mat B)
 91: {
 92:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

 94:   PetscFunctionBegin;
 95:   PetscCheck(lmvm->Xprev != NULL && lmvm->Fprev != NULL, PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_WRONGSTATE, "Must allocate Xprev and Fprev before allocating bases");
 96:   if (!lmvm->basis[LMBASIS_S]) PetscCall(LMBasisCreate(lmvm->Xprev, lmvm->m, &lmvm->basis[LMBASIS_S]));
 97:   if (!lmvm->basis[LMBASIS_Y]) PetscCall(LMBasisCreate(lmvm->Fprev, lmvm->m, &lmvm->basis[LMBASIS_Y]));
 98:   PetscFunctionReturn(PETSC_SUCCESS);
 99: }

101: PETSC_INTERN PetscErrorCode MatLMVMAllocateVecs(Mat B)
102: {
103:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

105:   PetscFunctionBegin;
106:   if (!lmvm->Xprev) PetscCall(MatCreateVecs(B, &lmvm->Xprev, NULL));
107:   if (!lmvm->Fprev) PetscCall(MatCreateVecs(B, NULL, &lmvm->Fprev));
108:   PetscFunctionReturn(PETSC_SUCCESS);
109: }

111: PETSC_INTERN PetscErrorCode MatAllocate_LMVM(Mat B, Vec X, Vec F)
112: {
113:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
114:   PetscBool same;
115:   VecType   vtype, Bvtype;

117:   PetscFunctionBegin;
118:   PetscCall(MatLMVMUseVecLayoutsIfCompatible(B, X, F));
119:   PetscCall(VecGetType(X, &vtype));
120:   PetscCall(MatGetVecType(B, &Bvtype));
121:   PetscCall(PetscStrcmp(vtype, Bvtype, &same));
122:   if (!same) {
123:     /* Given X vector has a different type than allocated X-type data structures.
124:        We need to destroy all of this and duplicate again out of the given vector. */
125:     PetscCall(MatLMVMReset_Internal(B, MAT_LMVM_RESET_BASES | MAT_LMVM_RESET_VECS));
126:     PetscCall(MatSetVecType(B, vtype));
127:     if (lmvm->created_J0) PetscCall(MatSetVecType(lmvm->J0, vtype));
128:   }
129:   PetscCall(MatLMVMAllocateVecs(B));
130:   PetscFunctionReturn(PETSC_SUCCESS);
131: }

133: PetscErrorCode MatUpdateKernel_LMVM(Mat B, Vec S, Vec Y)
134: {
135:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
136:   Vec       s_k, y_k;

138:   PetscFunctionBegin;
139:   PetscCall(LMBasisGetNextVec(lmvm->basis[LMBASIS_S], &s_k));
140:   PetscCall(VecCopy(S, s_k));
141:   PetscCall(LMBasisRestoreNextVec(lmvm->basis[LMBASIS_S], &s_k));

143:   PetscCall(LMBasisGetNextVec(lmvm->basis[LMBASIS_Y], &y_k));
144:   PetscCall(VecCopy(Y, y_k));
145:   PetscCall(LMBasisRestoreNextVec(lmvm->basis[LMBASIS_Y], &y_k));
146:   lmvm->nupdates++;
147:   lmvm->k++;
148:   PetscAssert(lmvm->k == lmvm->basis[LMBASIS_S]->k, PetscObjectComm((PetscObject)B), PETSC_ERR_PLIB, "Basis S and Mat B out of sync");
149:   PetscAssert(lmvm->k == lmvm->basis[LMBASIS_Y]->k, PetscObjectComm((PetscObject)B), PETSC_ERR_PLIB, "Basis Y and Mat B out of sync");
150:   PetscFunctionReturn(PETSC_SUCCESS);
151: }

153: PetscErrorCode MatUpdate_LMVM(Mat B, Vec X, Vec F)
154: {
155:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

157:   PetscFunctionBegin;
158:   if (!lmvm->m) PetscFunctionReturn(PETSC_SUCCESS);
159:   if (lmvm->prev_set) {
160:     /* Compute the new (S = X - Xprev) and (Y = F - Fprev) vectors */
161:     PetscCall(VecAXPBY(lmvm->Xprev, 1.0, -1.0, X));
162:     PetscCall(VecAXPBY(lmvm->Fprev, 1.0, -1.0, F));
163:     /* Update S and Y */
164:     PetscCall(MatUpdateKernel_LMVM(B, lmvm->Xprev, lmvm->Fprev));
165:   }

167:   /* Save the solution and function to be used in the next update */
168:   PetscCall(VecCopy(X, lmvm->Xprev));
169:   PetscCall(VecCopy(F, lmvm->Fprev));
170:   lmvm->prev_set = PETSC_TRUE;
171:   PetscFunctionReturn(PETSC_SUCCESS);
172: }

174: static PetscErrorCode MatMultAdd_LMVM(Mat B, Vec X, Vec Y, Vec Z)
175: {
176:   PetscFunctionBegin;
177:   PetscCall(MatMult(B, X, Z));
178:   PetscCall(VecAXPY(Z, 1.0, Y));
179:   PetscFunctionReturn(PETSC_SUCCESS);
180: }

182: static PetscErrorCode MatMult_LMVM(Mat B, Vec X, Vec Y)
183: {
184:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

186:   PetscFunctionBegin;
187:   PetscCall((*lmvm->ops->mult)(B, X, Y));
188:   if (lmvm->shift != 0.0) PetscCall(VecAXPY(Y, lmvm->shift, X));
189:   PetscFunctionReturn(PETSC_SUCCESS);
190: }

192: static PetscErrorCode MatMultHermitianTranspose_LMVM(Mat B, Vec X, Vec Y)
193: {
194:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

196:   PetscFunctionBegin;
197:   PetscCall((*lmvm->ops->multht)(B, X, Y));
198:   if (lmvm->shift != 0.0) PetscCall(VecAXPY(Y, PetscConj(lmvm->shift), X));
199:   PetscFunctionReturn(PETSC_SUCCESS);
200: }

202: static PetscErrorCode MatSolve_LMVM(Mat B, Vec x, Vec y)
203: {
204:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

206:   PetscFunctionBegin;
207:   PetscCheck(lmvm->shift == 0.0, PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_WRONGSTATE, "Cannot solve a MatLMVM when it has a nonzero shift");
208:   PetscCall((*lmvm->ops->solve)(B, x, y));
209:   PetscFunctionReturn(PETSC_SUCCESS);
210: }

212: static PetscErrorCode MatSolveHermitianTranspose_LMVM(Mat B, Vec x, Vec y)
213: {
214:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

216:   PetscFunctionBegin;
217:   PetscCheck(lmvm->shift == 0.0, PetscObjectComm((PetscObject)B), PETSC_ERR_ARG_WRONGSTATE, "Cannot solve a MatLMVM when it has a nonzero shift");
218:   PetscCall((*lmvm->ops->solveht)(B, x, y));
219:   PetscFunctionReturn(PETSC_SUCCESS);
220: }

222: static PetscErrorCode MatSolveTranspose_LMVM(Mat B, Vec x, Vec y)
223: {
224:   PetscFunctionBegin;
225:   if (!PetscDefined(USE_COMPLEX)) {
226:     PetscCall(MatSolveHermitianTranspose_LMVM(B, x, y));
227:   } else {
228:     Vec x_conj;
229:     PetscCall(VecDuplicate(x, &x_conj));
230:     PetscCall(VecCopy(x, x_conj));
231:     PetscCall(VecConjugate(x_conj));
232:     PetscCall(MatSolveHermitianTranspose_LMVM(B, x_conj, y));
233:     PetscCall(VecDestroy(&x_conj));
234:     PetscCall(VecConjugate(y));
235:   }
236:   PetscFunctionReturn(PETSC_SUCCESS);
237: }

239: // MatCopy() calls MatCheckPreallocated(), so B will have Xprev, Fprev, LMBASIS_S, and LMBASIS_Y
240: static PetscErrorCode MatCopy_LMVM(Mat B, Mat M, MatStructure str)
241: {
242:   Mat_LMVM *bctx = (Mat_LMVM *)B->data;
243:   Mat_LMVM *mctx;
244:   Mat       J0_copy;

246:   PetscFunctionBegin;
247:   if (str == DIFFERENT_NONZERO_PATTERN) {
248:     PetscCall(MatLMVMReset(M, PETSC_TRUE));
249:     PetscCall(MatLMVMAllocate(M, bctx->Xprev, bctx->Fprev));
250:   } else MatCheckSameSize(B, 1, M, 2);

252:   mctx = (Mat_LMVM *)M->data;
253:   PetscCall(MatDuplicate(bctx->J0, MAT_COPY_VALUES, &J0_copy));
254:   PetscCall(MatLMVMSetJ0(M, J0_copy));
255:   PetscCall(MatDestroy(&J0_copy));
256:   mctx->nupdates = bctx->nupdates;
257:   mctx->nrejects = bctx->nrejects;
258:   mctx->k        = bctx->k;
259:   PetscCall(MatLMVMAllocateVecs(M));
260:   PetscCall(VecCopy(bctx->Xprev, mctx->Xprev));
261:   PetscCall(VecCopy(bctx->Fprev, mctx->Fprev));
262:   PetscCall(MatLMVMAllocateBases(M));
263:   PetscCall(LMBasisCopy(bctx->basis[LMBASIS_S], mctx->basis[LMBASIS_S]));
264:   PetscCall(LMBasisCopy(bctx->basis[LMBASIS_Y], mctx->basis[LMBASIS_Y]));
265:   mctx->do_not_cache_J0_products = bctx->do_not_cache_J0_products;
266:   mctx->cache_gradient_products  = bctx->cache_gradient_products;
267:   mctx->mult_alg                 = bctx->mult_alg;
268:   if (mctx->ops->setmultalgorithm) PetscCall((*mctx->ops->setmultalgorithm)(M));
269:   if (bctx->ops->copy) PetscCall((*bctx->ops->copy)(B, M, str));
270:   PetscFunctionReturn(PETSC_SUCCESS);
271: }

273: static PetscErrorCode MatDuplicate_LMVM(Mat B, MatDuplicateOption op, Mat *mat)
274: {
275:   Mat_LMVM *bctx = (Mat_LMVM *)B->data;
276:   Mat_LMVM *mctx;
277:   MatType   lmvmType;
278:   Mat       A;

280:   PetscFunctionBegin;
281:   PetscCall(MatGetType(B, &lmvmType));
282:   PetscCall(MatCreate(PetscObjectComm((PetscObject)B), mat));
283:   PetscCall(MatSetType(*mat, lmvmType));

285:   A       = *mat;
286:   mctx    = (Mat_LMVM *)A->data;
287:   mctx->m = bctx->m;
288:   if (bctx->J0ksp) {
289:     PetscReal rtol, atol, dtol;
290:     PetscInt  max_it;

292:     PetscCall(KSPGetTolerances(bctx->J0ksp, &rtol, &atol, &dtol, &max_it));
293:     PetscCall(KSPSetTolerances(mctx->J0ksp, rtol, atol, dtol, max_it));
294:   }
295:   mctx->shift = bctx->shift;

297:   PetscCall(MatLMVMAllocate(*mat, bctx->Xprev, bctx->Fprev));
298:   if (op == MAT_COPY_VALUES) PetscCall(MatCopy(B, *mat, SAME_NONZERO_PATTERN));
299:   PetscFunctionReturn(PETSC_SUCCESS);
300: }

302: static PetscErrorCode MatShift_LMVM(Mat B, PetscScalar a)
303: {
304:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

306:   PetscFunctionBegin;
307:   lmvm->shift += PetscRealPart(a);
308:   PetscFunctionReturn(PETSC_SUCCESS);
309: }

311: PetscErrorCode MatView_LMVM(Mat B, PetscViewer pv)
312: {
313:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
314:   PetscBool isascii;
315:   MatType   type;

317:   PetscFunctionBegin;
318:   PetscCall(PetscObjectTypeCompare((PetscObject)pv, PETSCVIEWERASCII, &isascii));
319:   if (isascii) {
320:     PetscBool         is_exact;
321:     PetscViewerFormat format;

323:     PetscCall(MatGetType(B, &type));
324:     PetscCall(PetscViewerASCIIPrintf(pv, "Max. storage: %" PetscInt_FMT "\n", lmvm->m));
325:     PetscCall(PetscViewerASCIIPrintf(pv, "Used storage: %" PetscInt_FMT "\n", PetscMin(lmvm->k, lmvm->m)));
326:     PetscCall(PetscViewerASCIIPrintf(pv, "Number of updates: %" PetscInt_FMT "\n", lmvm->nupdates));
327:     PetscCall(PetscViewerASCIIPrintf(pv, "Number of rejects: %" PetscInt_FMT "\n", lmvm->nrejects));
328:     PetscCall(PetscViewerASCIIPrintf(pv, "Number of resets: %" PetscInt_FMT "\n", lmvm->nresets));
329:     PetscCall(PetscViewerGetFormat(pv, &format));
330:     if (format == PETSC_VIEWER_ASCII_INFO_DETAIL) {
331:       PetscCall(PetscViewerASCIIPrintf(pv, "Mult algorithm: %s\n", MatLMVMMultAlgorithms[lmvm->mult_alg]));
332:       PetscCall(PetscViewerASCIIPrintf(pv, "Cache J0 products: %s\n", lmvm->do_not_cache_J0_products ? "false" : "true"));
333:       PetscCall(PetscViewerASCIIPrintf(pv, "Cache gradient products: %s\n", lmvm->cache_gradient_products ? "true" : "false"));
334:     }
335:     PetscCall(MatLMVMJ0KSPIsExact(B, &is_exact));
336:     if (is_exact) {
337:       PetscBool is_scalar;

339:       PetscCall(PetscObjectTypeCompare((PetscObject)lmvm->J0, MATCONSTANTDIAGONAL, &is_scalar));
340:       PetscCall(PetscViewerASCIIPrintf(pv, "J0:\n"));
341:       PetscCall(PetscViewerASCIIPushTab(pv));
342:       PetscCall(PetscViewerPushFormat(pv, is_scalar ? PETSC_VIEWER_DEFAULT : PETSC_VIEWER_ASCII_INFO));
343:       PetscCall(MatView(lmvm->J0, pv));
344:       PetscCall(PetscViewerPopFormat(pv));
345:       PetscCall(PetscViewerASCIIPopTab(pv));
346:     } else {
347:       PetscCall(PetscViewerASCIIPrintf(pv, "J0 KSP:\n"));
348:       PetscCall(PetscViewerASCIIPushTab(pv));
349:       PetscCall(PetscViewerPushFormat(pv, PETSC_VIEWER_ASCII_INFO));
350:       PetscCall(KSPView(lmvm->J0ksp, pv));
351:       PetscCall(PetscViewerPopFormat(pv));
352:       PetscCall(PetscViewerASCIIPopTab(pv));
353:     }
354:   }
355:   PetscFunctionReturn(PETSC_SUCCESS);
356: }

358: PetscErrorCode MatSetFromOptions_LMVM(Mat B, PetscOptionItems PetscOptionsObject)
359: {
360:   Mat_LMVM            *lmvm     = (Mat_LMVM *)B->data;
361:   PetscBool            cache_J0 = lmvm->do_not_cache_J0_products ? PETSC_FALSE : PETSC_TRUE; // Default is false, but flipping double negative so that the command line option make sense
362:   PetscBool            set;
363:   PetscInt             hist_size = lmvm->m;
364:   MatLMVMMultAlgorithm mult_alg;

366:   PetscFunctionBegin;
367:   PetscCall(MatLMVMGetMultAlgorithm(B, &mult_alg));
368:   PetscOptionsHeadBegin(PetscOptionsObject, "Limited-memory Variable Metric matrix for approximating Jacobians");
369:   PetscCall(PetscOptionsInt("-mat_lmvm_hist_size", "number of past updates kept in memory for the approximation", "", hist_size, &hist_size, NULL));
370:   PetscCall(PetscOptionsEnum("-mat_lmvm_mult_algorithm", "Algorithm used to matrix-vector products", "", MatLMVMMultAlgorithms, (PetscEnum)mult_alg, (PetscEnum *)&mult_alg, &set));
371:   PetscCall(PetscOptionsReal("-mat_lmvm_eps", "(developer) machine zero definition", "", lmvm->eps, &lmvm->eps, NULL));
372:   PetscCall(PetscOptionsBool("-mat_lmvm_cache_J0_products", "Cache applications of the kernel J0 or its inverse", "", cache_J0, &cache_J0, NULL));
373:   PetscCall(PetscOptionsBool("-mat_lmvm_cache_gradient_products", "Cache data used to apply the inverse Hessian to a gradient vector to accelerate the quasi-Newton update", "", lmvm->cache_gradient_products, &lmvm->cache_gradient_products, NULL));
374:   PetscCall(PetscOptionsBool("-mat_lmvm_debug", "(developer) Perform internal debugging checks", "", lmvm->debug, &lmvm->debug, NULL));
375:   PetscOptionsHeadEnd();
376:   lmvm->do_not_cache_J0_products = cache_J0 ? PETSC_FALSE : PETSC_TRUE;
377:   if (hist_size != lmvm->m) PetscCall(MatLMVMSetHistorySize(B, hist_size));
378:   if (set) PetscCall(MatLMVMSetMultAlgorithm(B, mult_alg));
379:   if (lmvm->created_J0) PetscCall(MatSetFromOptions(lmvm->J0));
380:   if (lmvm->created_J0ksp) PetscCall(KSPSetFromOptions(lmvm->J0ksp));
381:   PetscFunctionReturn(PETSC_SUCCESS);
382: }

384: PetscErrorCode MatSetUp_LMVM(Mat B)
385: {
386:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

388:   PetscFunctionBegin;
389:   PetscCall(PetscLayoutSetUp(B->rmap));
390:   PetscCall(PetscLayoutSetUp(B->cmap));
391:   if (lmvm->created_J0) {
392:     PetscCall(PetscLayoutReference(B->rmap, &lmvm->J0->rmap));
393:     PetscCall(PetscLayoutReference(B->cmap, &lmvm->J0->cmap));
394:     PetscCall(MatSetUp(lmvm->J0));
395:   }
396:   PetscCall(MatLMVMAllocateVecs(B));
397:   PetscCall(MatLMVMAllocateBases(B));
398:   PetscFunctionReturn(PETSC_SUCCESS);
399: }

401: /*@
402:   MatLMVMSetMultAlgorithm - Set the algorithm used by a `MatLMVM` for products

404:   Logically collective

406:   Input Parameters:
407: + B   - a `MatLMVM` matrix
408: - alg - one of the algorithm classes (`MAT_LMVM_MULT_RECURSIVE`, `MAT_LMVM_MULT_DENSE`, `MAT_LMVM_MULT_COMPACT_DENSE`)

410:   Level: advanced

412: .seealso: [](ch_matrices), `MatLMVM`, `MatLMVMMultAlgorithm`, `MatLMVMGetMultAlgorithm()`
413: @*/
414: PetscErrorCode MatLMVMSetMultAlgorithm(Mat B, MatLMVMMultAlgorithm alg)
415: {
416:   PetscFunctionBegin;
418:   PetscTryMethod(B, "MatLMVMSetMultAlgorithm_C", (Mat, MatLMVMMultAlgorithm), (B, alg));
419:   PetscFunctionReturn(PETSC_SUCCESS);
420: }

422: static PetscErrorCode MatLMVMSetMultAlgorithm_LMVM(Mat B, MatLMVMMultAlgorithm alg)
423: {
424:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

426:   PetscFunctionBegin;
427:   lmvm->mult_alg = alg;
428:   if (lmvm->ops->setmultalgorithm) PetscCall((*lmvm->ops->setmultalgorithm)(B));
429:   PetscFunctionReturn(PETSC_SUCCESS);
430: }

432: /*@
433:   MatLMVMGetMultAlgorithm - Get the algorithm used by a `MatLMVM` for products

435:   Not collective

437:   Input Parameter:
438: . B - a `MatLMVM` matrix

440:   Output Parameter:
441: . alg - one of the algorithm classes (`MAT_LMVM_MULT_RECURSIVE`, `MAT_LMVM_MULT_DENSE`, `MAT_LMVM_MULT_COMPACT_DENSE`)

443:   Level: advanced

445: .seealso: [](ch_matrices), `MatLMVM`, `MatLMVMMultAlgorithm`, `MatLMVMSetMultAlgorithm()`
446: @*/
447: PetscErrorCode MatLMVMGetMultAlgorithm(Mat B, MatLMVMMultAlgorithm *alg)
448: {
449:   PetscFunctionBegin;
451:   PetscAssertPointer(alg, 2);
452:   PetscUseMethod(B, "MatLMVMGetMultAlgorithm_C", (Mat, MatLMVMMultAlgorithm *), (B, alg));
453:   PetscFunctionReturn(PETSC_SUCCESS);
454: }

456: static PetscErrorCode MatLMVMGetMultAlgorithm_LMVM(Mat B, MatLMVMMultAlgorithm *alg)
457: {
458:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

460:   PetscFunctionBegin;
461:   *alg = lmvm->mult_alg;
462:   PetscFunctionReturn(PETSC_SUCCESS);
463: }

465: PetscErrorCode MatDestroy_LMVM(Mat B)
466: {
467:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

469:   PetscFunctionBegin;
470:   PetscCall(MatReset_LMVM(B, MAT_LMVM_RESET_ALL));
471:   PetscCall(KSPDestroy(&lmvm->J0ksp));
472:   PetscCall(MatDestroy(&lmvm->J0));
473:   PetscCall(PetscFree(B->data));
474:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMGetLastUpdate_C", NULL));
475:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSetMultAlgorithm_C", NULL));
476:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMGetMultAlgorithm_C", NULL));
477:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetOptionsPrefix_C", NULL));
478:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatAppendOptionsPrefix_C", NULL));
479:   PetscFunctionReturn(PETSC_SUCCESS);
480: }

482: /*@
483:   MatLMVMGetLastUpdate - Get the last vectors passed to `MatLMVMUpdate()`

485:   Not collective

487:   Input Parameter:
488: . B - a `MatLMVM` matrix

490:   Output Parameters:
491: + x_prev - the last solution vector
492: - f_prev - the last function vector

494:   Level: intermediate

496: .seealso: [](ch_matrices), `MatLMVM`, `MatLMVMUpdate()`
497: @*/
498: PetscErrorCode MatLMVMGetLastUpdate(Mat B, Vec *x_prev, Vec *f_prev)
499: {
500:   PetscFunctionBegin;
502:   PetscTryMethod(B, "MatLMVMGetLastUpdate_C", (Mat, Vec *, Vec *), (B, x_prev, f_prev));
503:   PetscFunctionReturn(PETSC_SUCCESS);
504: }

506: static PetscErrorCode MatLMVMGetLastUpdate_LMVM(Mat B, Vec *x_prev, Vec *f_prev)
507: {
508:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

510:   PetscFunctionBegin;
511:   if (x_prev) *x_prev = (lmvm->prev_set) ? lmvm->Xprev : NULL;
512:   if (f_prev) *f_prev = (lmvm->prev_set) ? lmvm->Fprev : NULL;
513:   PetscFunctionReturn(PETSC_SUCCESS);
514: }

516: /* in both MatSetOptionsPrefix() and MatAppendOptionsPrefix(), this is called after
517:    the prefix of B has been changed, so we just query the prefix of B rather than
518:    using the passed prefix */
519: static PetscErrorCode MatSetOptionsPrefix_LMVM(Mat B, const char unused[])
520: {
521:   Mat_LMVM   *lmvm = (Mat_LMVM *)B->data;
522:   const char *prefix;

524:   PetscFunctionBegin;
525:   PetscCall(MatGetOptionsPrefix(B, &prefix));
526:   if (lmvm->created_J0) {
527:     PetscCall(MatSetOptionsPrefix(lmvm->J0, prefix));
528:     PetscCall(MatAppendOptionsPrefix(lmvm->J0, "mat_lmvm_J0_"));
529:   }
530:   if (lmvm->created_J0ksp) {
531:     PetscCall(KSPSetOptionsPrefix(lmvm->J0ksp, prefix));
532:     PetscCall(KSPAppendOptionsPrefix(lmvm->J0ksp, "mat_lmvm_J0_"));
533:   }
534:   PetscFunctionReturn(PETSC_SUCCESS);
535: }

537: /*MC
538:    MATLMVM - MATLMVM = "lmvm" - A matrix type used for Limited-Memory Variable Metric (LMVM) matrices.

540:    Level: intermediate

542:    Developer notes:
543:    Improve this manual page as well as many others in the MATLMVM family.

545: .seealso: [](sec_matlmvm), `Mat`
546: M*/
547: PetscErrorCode MatCreate_LMVM(Mat B)
548: {
549:   Mat_LMVM *lmvm;

551:   PetscFunctionBegin;
552:   PetscCall(MatLMVMPackageInitialize());
553:   PetscCall(PetscNew(&lmvm));
554:   B->data = (void *)lmvm;

556:   lmvm->m   = 5;
557:   lmvm->eps = PetscPowReal(PETSC_MACHINE_EPSILON, 2.0 / 3.0);

559:   B->ops->destroy                = MatDestroy_LMVM;
560:   B->ops->setfromoptions         = MatSetFromOptions_LMVM;
561:   B->ops->view                   = MatView_LMVM;
562:   B->ops->setup                  = MatSetUp_LMVM;
563:   B->ops->shift                  = MatShift_LMVM;
564:   B->ops->duplicate              = MatDuplicate_LMVM;
565:   B->ops->mult                   = MatMult_LMVM;
566:   B->ops->multhermitiantranspose = MatMultHermitianTranspose_LMVM;
567:   B->ops->multadd                = MatMultAdd_LMVM;
568:   B->ops->copy                   = MatCopy_LMVM;
569:   B->ops->solve                  = MatSolve_LMVM;
570:   B->ops->solvetranspose         = MatSolveTranspose_LMVM;
571:   if (!PetscDefined(USE_COMPLEX)) B->ops->multtranspose = MatMultHermitianTranspose_LMVM;

573:   /*
574:     There is no assembly phase, Mat_LMVM relies on B->preallocated to ensure that
575:     necessary setup happens in MatSetUp(), which is called in MatCheckPreallocated()
576:     in all major operations (MatLMVMUpdate(), MatMult(), MatSolve(), etc.)
577:    */
578:   B->assembled = PETSC_TRUE;

580:   lmvm->ops->update = MatUpdate_LMVM;

582:   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVM));
583:   // J0 should be present at all times, calling ClearJ0() here initializes it to the identity
584:   PetscCall(MatLMVMClearJ0(B));

586:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMGetLastUpdate_C", MatLMVMGetLastUpdate_LMVM));
587:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSetMultAlgorithm_C", MatLMVMSetMultAlgorithm_LMVM));
588:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMGetMultAlgorithm_C", MatLMVMGetMultAlgorithm_LMVM));
589:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetOptionsPrefix_C", MatSetOptionsPrefix_LMVM));
590:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatAppendOptionsPrefix_C", MatSetOptionsPrefix_LMVM));
591:   PetscFunctionReturn(PETSC_SUCCESS);
592: }