Actual source code: sr1.c

  1: #include <../src/ksp/ksp/utils/lmvm/lmvm.h>

  3: /*
  4:   Limited-memory Symmetric-Rank-1 method for approximating both
  5:   the forward product and inverse application of a Jacobian.
  6: */

// bases needed by SR1 algorithms beyond those in Mat_LMVM
// The values index Mat_LSR1.basis[]; the kernels in this file pass them through
// LMVMModeMap() together with a MatLMVMMode before indexing.
enum {
  SR1_BASIS_Y_MINUS_BKS = 0, // Y_k - B_k S_k for recursive algorithms
  SR1_BASIS_S_MINUS_HKY = 1, // dual to the above, S_k - H_k Y_k
  SR1_BASIS_COUNT            // number of SR1-specific bases (sizes Mat_LSR1.basis[])
};

// integer type used to hold one of the SR1_BASIS_* values above
typedef PetscInt SR1BasisType;

// products needed by SR1 algorithms beyond those in Mat_LMVM
// The values index Mat_LSR1.products[]; the kernels in this file pass them through
// LMVMModeMap() together with a MatLMVMMode before indexing.
enum {
  SR1_PRODUCTS_YTS_MINUS_STB0S = 0, // stores and factors symm(triu((Y - B_0 S)^T S)) for compact algorithms
  SR1_PRODUCTS_STY_MINUS_YTH0Y = 1, // dual to the above, stores and factors symm(triu((S - H_0 Y)^T Y))
  SR1_PRODUCTS_YTS_MINUS_STBKS = 2, // diagonal (Y_k - B_k S_k)^T S_k values for recursive algorithms
  SR1_PRODUCTS_STY_MINUS_YTHKY = 3, // dual to the above, diagonal (S_k - H_k Y_k)^T Y_k
  SR1_PRODUCTS_COUNT                // number of SR1-specific products (sizes Mat_LSR1.products[])
};

// integer type used to hold one of the SR1_PRODUCTS_* values above
typedef PetscInt SR1ProductsType;

/* SR1-specific context, stored in Mat_LMVM::ctx by MatCreate_LMVMSR1() */
typedef struct {
  LMBasis    basis[SR1_BASIS_COUNT];       // lazily created bases, indexed by SR1BasisType
  LMProducts products[SR1_PRODUCTS_COUNT]; // lazily created products, indexed by SR1ProductsType
  // SmH0YtFprev caches the product of the (S - H_0 Y) basis with the most recent F (see MatUpdate_LMVMSR1());
  // StFprev is only destroyed / zeroed in MatReset_LMVMSR1() within this file
  Vec        StFprev, SmH0YtFprev;
} Mat_LSR1;

 34: /* The SR1 kernel can be written as

 36:      B_{k+1} = B_k + (y_k - B_k s_k) ((y_k - B_k s_k)^T s_k)^{-1} (y_k - B_k s_k)^T

 38:    this unrolls to a rank-m update

 40:      B_{k+1} = B_0 + \sum_{i = k-m+1}^k (y_i - B_i s_i) ((y_i - B_i s_i)^T s_i)^{-1} (y_i - B_i s_i)^T

 42:    This inner kernel assumes the (y_i - B_i s_i) vectors and the ((y_i - B_i s_i)^T s_i) products have been computed
 43:  */

 45: static PetscErrorCode SR1Kernel_Recursive_Inner(Mat B, MatLMVMMode mode, PetscInt oldest, PetscInt next, Vec X, Vec BX)
 46: {
 47:   Mat_LMVM       *lmvm              = (Mat_LMVM *)B->data;
 48:   Mat_LSR1       *lsr1              = (Mat_LSR1 *)lmvm->ctx;
 49:   SR1BasisType    Y_minus_BkS_t     = LMVMModeMap(SR1_BASIS_Y_MINUS_BKS, mode);
 50:   SR1ProductsType YtS_minus_StBkS_t = LMVMModeMap(SR1_PRODUCTS_STY_MINUS_YTHKY, mode);
 51:   LMBasis         Y_minus_BkS       = lsr1->basis[Y_minus_BkS_t];
 52:   LMProducts      YtS_minus_StBkS   = lsr1->products[YtS_minus_StBkS_t];
 53:   Vec             YmBkStX;

 55:   PetscFunctionBegin;
 56:   PetscCall(MatLMVMGetWorkRow(B, &YmBkStX));
 57:   PetscCall(LMBasisGEMVH(Y_minus_BkS, oldest, next, 1.0, X, 0.0, YmBkStX));
 58:   PetscCall(LMProductsSolve(YtS_minus_StBkS, oldest, next, YmBkStX, YmBkStX, /* ^H */ PETSC_FALSE));
 59:   PetscCall(LMBasisGEMV(Y_minus_BkS, oldest, next, 1.0, YmBkStX, 1.0, BX));
 60:   PetscCall(MatLMVMRestoreWorkRow(B, &YmBkStX));
 61:   PetscFunctionReturn(PETSC_SUCCESS);
 62: }

 64: /* Recursively compute the (y_i - B_i s_i) vectors and ((y_i - B_i s_i)^T s_i) products */

/*
  Brings the SR1-specific basis of (y_j - B_j s_j) vectors and the matching diagonal
  products ((y_j - B_j s_j)^T s_j) up to date for the current history range of B.

  The basis and products objects are created on first use.  Columns are (re)computed in
  increasing j because column j needs the columns [oldest, j) that precede it.

  The variable names below describe the primal mode; for the dual mode every type is passed
  through LMVMModeMap(), which presumably swaps the roles of S/Y and B/H -- TODO confirm.

  NOTE(review): the products slot is looked up through SR1_PRODUCTS_STY_MINUS_YTHKY rather than
  SR1_PRODUCTS_YTS_MINUS_STBKS; SR1Kernel_Recursive_Inner() uses the same mapping, so the two
  routines agree with each other -- confirm the choice of enumerator is intended.
*/
static PetscErrorCode SR1RecursiveBasisUpdate(Mat B, MatLMVMMode mode)
{
  Mat_LMVM        *lmvm              = (Mat_LMVM *)B->data;
  Mat_LSR1        *lsr1              = (Mat_LSR1 *)lmvm->ctx;
  MatLMVMBasisType B0S_t             = LMVMModeMap(LMBASIS_B0S, mode);
  MatLMVMBasisType S_t               = LMVMModeMap(LMBASIS_S, mode);
  MatLMVMBasisType Y_t               = LMVMModeMap(LMBASIS_Y, mode);
  SR1BasisType     Y_minus_BkS_t     = LMVMModeMap(SR1_BASIS_Y_MINUS_BKS, mode);
  SR1ProductsType  YtS_minus_StBkS_t = LMVMModeMap(SR1_PRODUCTS_STY_MINUS_YTHKY, mode);
  LMBasis          Y_minus_BkS;
  LMProducts       YtS_minus_StBkS;
  PetscInt         oldest, next;
  PetscInt         products_oldest;
  LMBasis          S, Y;
  PetscInt         start;

  PetscFunctionBegin;
  // lazily create the basis, using Fprev (primal) or Xprev (dual) as the template vector and m columns
  if (!lsr1->basis[Y_minus_BkS_t]) PetscCall(LMBasisCreate(mode == MATLMVM_MODE_PRIMAL ? lmvm->Fprev : lmvm->Xprev, lmvm->m, &lsr1->basis[Y_minus_BkS_t]));
  Y_minus_BkS = lsr1->basis[Y_minus_BkS_t];
  // lazily create the diagonal products
  if (!lsr1->products[YtS_minus_StBkS_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_DIAGONAL, &lsr1->products[YtS_minus_StBkS_t]));
  YtS_minus_StBkS = lsr1->products[YtS_minus_StBkS_t];
  PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
  PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
  PetscCall(MatLMVMGetRange(B, &oldest, &next));
  // invalidate computed values if J0 has changed
  PetscCall(LMProductsPrepare(YtS_minus_StBkS, lmvm->J0, oldest, next));
  // oldest index the previously computed products could have started from
  products_oldest = PetscMax(0, YtS_minus_StBkS->k - lmvm->m);
  if (oldest > products_oldest) {
    // recursion is starting from a different starting index, it must be recomputed
    YtS_minus_StBkS->k = oldest;
  }
  // keep the basis and the products in lock step: both restart from the same index
  Y_minus_BkS->k = start = YtS_minus_StBkS->k;
  // recompute each column in Y_minus_BkS in order
  for (PetscInt j = start; j < next; j++) {
    Vec         s_j, B0s_j, p_j, y_j;
    PetscScalar alpha, ymbksts;

    PetscCall(LMBasisGetWorkVec(Y_minus_BkS, &p_j));

    // p_j starts as B_0 * s_j (the stored vector B0s_j scaled by alpha)
    PetscCall(MatLMVMBasisGetVecRead(B, B0S_t, j, &B0s_j, &alpha));
    PetscCall(VecAXPBY(p_j, alpha, 0.0, B0s_j));
    PetscCall(MatLMVMBasisRestoreVecRead(B, B0S_t, j, &B0s_j, &alpha));

    // Use the matmult kernel to add the rank updates from columns [oldest, j), giving p_j = B_j * s_j
    PetscCall(LMBasisGetVecRead(S, j, &s_j));
    // if j == oldest p_j is already correct
    if (j > oldest) PetscCall(SR1Kernel_Recursive_Inner(B, mode, oldest, j, s_j, p_j));
    PetscCall(LMBasisGetVecRead(Y, j, &y_j));
    // p_j = y_j - B_j * s_j
    PetscCall(VecAYPX(p_j, -1.0, y_j));
    // diagonal entry j is the VecDot() of s_j with the new column
    PetscCall(VecDot(s_j, p_j, &ymbksts));
    PetscCall(LMProductsInsertNextDiagonalValue(YtS_minus_StBkS, j, ymbksts));
    PetscCall(LMBasisRestoreVecRead(S, j, &s_j));
    PetscCall(LMBasisRestoreVecRead(Y, j, &y_j));
    // append the finished column before handing the work vector back
    PetscCall(LMBasisSetNextVec(Y_minus_BkS, p_j));
    PetscCall(LMBasisRestoreWorkVec(Y_minus_BkS, &p_j));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

126: static PetscErrorCode SR1Kernel_Recursive(Mat B, MatLMVMMode mode, Vec X, Vec BX)
127: {
128:   PetscInt oldest, next;

130:   PetscFunctionBegin;
131:   PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, BX));
132:   PetscCall(MatLMVMGetRange(B, &oldest, &next));
133:   if (next > oldest) {
134:     PetscCall(SR1RecursiveBasisUpdate(B, mode));
135:     PetscCall(SR1Kernel_Recursive_Inner(B, mode, oldest, next, X, BX));
136:   }
137:   PetscFunctionReturn(PETSC_SUCCESS);
138: }

140: /* The SR1 kernel can be written as (See Byrd, Schnabel & Nocedal 1994)

142:      B_{k+1} = B_0 + (Y - B_0 S) (diag(S^T Y) + stril(S^T Y) + stril(S^T Y)^T - S^T B_0 S)^{-1} (Y - B_0 S)^T
143:                                  \___________________________ ___________________________/
144:                                                              V
145:                                                              M

147:    M is symmetric indefinite (stril is the strictly lower triangular part)

 149:    M can be computed by computing triu((Y - B_0 S)^T S) and filling in the lower triangle
150:  */

/*
  Brings the factored middle matrix M of the compact SR1 representation up to date for the
  current history range of B.

  The products object is created (full block) on first use.  If it is older than the current
  `next` index, the upper triangle of (Y - B_0 S)^T S is fetched from B, copied into the local
  matrix, completed to a Hermitian matrix, and factored in place.

  The variable names below describe the primal mode; for the dual mode every type is passed
  through LMVMModeMap().
*/
static PetscErrorCode SR1CompactProductsUpdate(Mat B, MatLMVMMode mode)
{
  Mat_LMVM        *lmvm              = (Mat_LMVM *)B->data;
  Mat_LSR1        *lsr1              = (Mat_LSR1 *)lmvm->ctx;
  MatLMVMBasisType S_t               = LMVMModeMap(LMBASIS_S, mode);
  MatLMVMBasisType YmB0S_t           = LMVMModeMap(LMBASIS_Y_MINUS_B0S, mode);
  SR1ProductsType  YtS_minus_StB0S_t = LMVMModeMap(SR1_PRODUCTS_YTS_MINUS_STB0S, mode);
  LMProducts       YtS_minus_StB0S;
  Mat              local;
  PetscInt         oldest, next, k;
  PetscBool        local_is_nonempty;

  PetscFunctionBegin;
  // lazily create the full-block products that will hold the factored M
  if (!lsr1->products[YtS_minus_StB0S_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_FULL, &lsr1->products[YtS_minus_StB0S_t]));
  YtS_minus_StB0S = lsr1->products[YtS_minus_StB0S_t];
  PetscCall(MatLMVMGetRange(B, &oldest, &next));
  PetscCall(LMProductsPrepare(YtS_minus_StB0S, lmvm->J0, oldest, next));
  PetscCall(LMProductsGetLocalMatrix(YtS_minus_StB0S, &local, &k, &local_is_nonempty));
  // only refactor when the stored factorization is behind the current history
  if (YtS_minus_StB0S->k < next) {
    // copy to factor in place
    LMProducts YmB0StS;
    Mat        ymb0sts_local;

    PetscCall(PetscCitationsRegister(ByrdNocedalSchnabelCitation, &ByrdNocedalSchnabelCite));
    YtS_minus_StB0S->k = next;
    PetscCall(MatLMVMGetUpdatedProducts(B, YmB0S_t, S_t, LMBLOCK_UPPER_TRIANGLE, &YmB0StS));
    PetscCall(LMProductsGetLocalMatrix(YmB0StS, &ymb0sts_local, NULL, NULL));
    if (local_is_nonempty) {
      PetscErrorCode ierr;

      // discard any previous factorization before overwriting the entries
      PetscCall(MatSetUnfactored(local));
      PetscCall(MatCopy(ymb0sts_local, local, SAME_NONZERO_PATTERN));
      // fill in the other triangle so the matrix is Hermitian
      PetscCall(LMProductsMakeHermitian(local, oldest, next));
      // presumably keeps rows / columns outside [oldest, next) from making the matrix singular -- TODO confirm
      PetscCall(LMProductsOnesOnUnusedDiagonal(local, oldest, next));
      PetscCall(MatSetOption(local, MAT_HERMITIAN, PETSC_TRUE));
      // Set not spd so that "Cholesky" factorization is actually the symmetric indefinite Bunch Kaufman factorization
      PetscCall(MatSetOption(local, MAT_SPD, PETSC_FALSE));

      // capture the error code instead of aborting so that PETSC_ERR_SUP can be handled below
      PetscCall(PetscPushErrorHandler(PetscReturnErrorHandler, NULL));
      ierr = MatCholeskyFactor(local, NULL, NULL);
      PetscCall(PetscPopErrorHandler());
      PetscCheck(ierr == PETSC_SUCCESS || ierr == PETSC_ERR_SUP, PETSC_COMM_SELF, ierr, "Error in Bunch-Kaufman factorization");
      // cusolver does not provide Bunch Kaufman, resort to LU if it is unavailable
      if (ierr == PETSC_ERR_SUP) PetscCall(MatLUFactor(local, NULL, NULL, NULL));
    }
    PetscCall(LMProductsRestoreLocalMatrix(YmB0StS, &ymb0sts_local, NULL));
  }
  // hand the local matrix back, passing `next` as the index the products are now valid up to
  PetscCall(LMProductsRestoreLocalMatrix(YtS_minus_StB0S, &local, &next));
  PetscFunctionReturn(PETSC_SUCCESS);
}

203: static PetscErrorCode SR1Kernel_CompactDense(Mat B, MatLMVMMode mode, Vec X, Vec BX)
204: {
205:   PetscInt oldest, next;

207:   PetscFunctionBegin;
208:   PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, BX));
209:   PetscCall(MatLMVMGetRange(B, &oldest, &next));
210:   if (next > oldest) {
211:     Mat_LMVM        *lmvm              = (Mat_LMVM *)B->data;
212:     Mat_LSR1        *lsr1              = (Mat_LSR1 *)lmvm->ctx;
213:     MatLMVMBasisType Y_minus_B0S_t     = LMVMModeMap(LMBASIS_Y_MINUS_B0S, mode);
214:     SR1ProductsType  YtS_minus_StB0S_t = LMVMModeMap(SR1_PRODUCTS_YTS_MINUS_STB0S, mode);
215:     LMProducts       YtS_minus_StB0S;
216:     Vec              YmB0StX, v;

218:     PetscCall(SR1CompactProductsUpdate(B, mode));
219:     YtS_minus_StB0S = lsr1->products[YtS_minus_StB0S_t];
220:     PetscCall(MatLMVMGetWorkRow(B, &YmB0StX));
221:     PetscCall(MatLMVMGetWorkRow(B, &v));
222:     if (lmvm->do_not_cache_J0_products) {
223:       /* the initial (Y - B_0 S)^T x inner product can be computed as Y^T x - S^T (B_0 x)
224:          if we are not caching B_0 S products */
225:       MatLMVMBasisType S_t = LMVMModeMap(LMBASIS_S, mode);
226:       MatLMVMBasisType Y_t = LMVMModeMap(LMBASIS_Y, mode);
227:       LMBasis          S, Y;

229:       PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
230:       PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
231:       PetscCall(LMBasisGEMVH(Y, oldest, next, 1.0, X, 0.0, YmB0StX));
232:       PetscCall(LMBasisGEMVH(S, oldest, next, -1.0, BX, 1.0, YmB0StX));
233:     } else PetscCall(MatLMVMBasisGEMVH(B, Y_minus_B0S_t, oldest, next, 1.0, X, 0.0, YmB0StX));
234:     PetscCall(LMProductsSolve(YtS_minus_StB0S, oldest, next, YmB0StX, v, PETSC_FALSE));
235:     PetscCall(MatLMVMBasisGEMV(B, Y_minus_B0S_t, oldest, next, 1.0, v, 1.0, BX));
236:     PetscCall(MatLMVMRestoreWorkRow(B, &v));
237:     PetscCall(MatLMVMRestoreWorkRow(B, &YmB0StX));
238:   }
239:   PetscFunctionReturn(PETSC_SUCCESS);
240: }

242: static PetscErrorCode MatMult_LMVMSR1_CompactDense(Mat B, Vec X, Vec BX)
243: {
244:   PetscFunctionBegin;
245:   PetscCall(SR1Kernel_CompactDense(B, MATLMVM_MODE_PRIMAL, X, BX));
246:   PetscFunctionReturn(PETSC_SUCCESS);
247: }

249: static PetscErrorCode MatSolve_LMVMSR1_CompactDense(Mat B, Vec X, Vec BX)
250: {
251:   PetscFunctionBegin;
252:   PetscCall(SR1Kernel_CompactDense(B, MATLMVM_MODE_DUAL, X, BX));
253:   PetscFunctionReturn(PETSC_SUCCESS);
254: }

256: static PetscErrorCode MatMult_LMVMSR1_Recursive(Mat B, Vec X, Vec Z)
257: {
258:   PetscFunctionBegin;
259:   PetscCall(SR1Kernel_Recursive(B, MATLMVM_MODE_PRIMAL, X, Z));
260:   PetscFunctionReturn(PETSC_SUCCESS);
261: }

263: static PetscErrorCode MatSolve_LMVMSR1_Recursive(Mat B, Vec F, Vec dX)
264: {
265:   PetscFunctionBegin;
266:   PetscCall(SR1Kernel_Recursive(B, MATLMVM_MODE_DUAL, F, dX));
267:   PetscFunctionReturn(PETSC_SUCCESS);
268: }

/*
  Update callback (installed as lmvm->ops->update by MatCreate_LMVMSR1()).

  Given a new iterate X and function value F:
    - returns immediately if the history size lmvm->m is zero;
    - if a previous (X, F) pair is stored, forms S = X - Xprev and Y = F - Fprev in place in
      lmvm->Xprev / lmvm->Fprev, tests whether the pair is acceptable, and either stores it
      via MatUpdateKernel_LMVM() or counts a rejection in lmvm->nrejects;
    - finally stores X and F as the new previous pair.

  When cache_SmH0YtF is true, the product of the (S - H_0 Y) basis with the latest F is kept
  in sr1->SmH0YtFprev, and is used to fill the next column of the (S - H_0 Y)^T Y products
  without recomputing it from scratch.
*/
static PetscErrorCode MatUpdate_LMVMSR1(Mat B, Vec X, Vec F)
{
  Mat_LMVM *lmvm          = (Mat_LMVM *)B->data;
  Mat_LSR1 *sr1           = (Mat_LSR1 *)lmvm->ctx;
  // caching is only used for the non-recursive algorithms, and only when J0 products are cached
  PetscBool cache_SmH0YtF = (lmvm->mult_alg != MAT_LMVM_MULT_RECURSIVE && !lmvm->do_not_cache_J0_products) ? lmvm->cache_gradient_products : PETSC_FALSE;

  PetscFunctionBegin;
  if (!lmvm->m) PetscFunctionReturn(PETSC_SUCCESS);
  if (lmvm->prev_set) {
    PetscReal   snorm, pnorm;
    PetscScalar sktw;
    Vec         work;
    Vec         Fprev_old       = NULL;
    Vec         SmH0YtFprev_old = NULL;
    LMProducts  SmH0YtY         = NULL;
    PetscInt    oldest, next;
    LMBasis     SmH0Y = NULL;
    LMBasis     Y;

    PetscCall(MatLMVMGetRange(B, &oldest, &next));
    if (cache_SmH0YtF) {
      PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_S_MINUS_H0Y, &SmH0Y, NULL, NULL));
      if (!sr1->SmH0YtFprev) PetscCall(LMBasisCreateRow(SmH0Y, &sr1->SmH0YtFprev));
      // Fprev_old keeps a copy of Fprev, which is overwritten with Y below
      PetscCall(LMBasisGetWorkVec(SmH0Y, &Fprev_old));
      PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_S_MINUS_H0Y, LMBASIS_Y, LMBLOCK_UPPER_TRIANGLE, &SmH0YtY));
      // SmH0YtFprev_old is the next column of the products; it first receives (S - H_0 Y)^T Fprev
      PetscCall(LMProductsGetNextColumn(SmH0YtY, &SmH0YtFprev_old));
      PetscCall(VecCopy(lmvm->Fprev, Fprev_old));
      if (sr1->SmH0YtFprev == SmH0Y->cached_product) {
        // the basis still regards our vector as its cached product: reuse it
        PetscCall(VecCopy(sr1->SmH0YtFprev, SmH0YtFprev_old));
      } else {
        if (next > oldest) {
          // need to recalculate
          PetscCall(LMBasisGEMVH(SmH0Y, oldest, next, 1.0, Fprev_old, 0.0, SmH0YtFprev_old));
        } else {
          PetscCall(VecZeroEntries(SmH0YtFprev_old));
        }
      }
    }

    /* Compute the new (S = X - Xprev) and (Y = F - Fprev) vectors */
    PetscCall(VecAYPX(lmvm->Xprev, -1.0, X));
    PetscCall(VecAYPX(lmvm->Fprev, -1.0, F));

    /* See if the updates can be accepted
       NOTE: This tests abs(S[k]^T (Y[k] - B_k*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_k*S[k])

       Note that this test is flawed because this is a **limited memory** SR1 method: we are testing

         abs(S[k]^T (Y[k] - B_{k,m}*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_{k,m}*S[k])

       when the oldest pair of vectors in the definition of B_{k,m}, (s_{k-m}, y_{k-m}), will be dropped if we add a new
       pair.  To really ensure that B_{k+1} = B_{k+1,m} is nonsingular, you need to test

         abs(S[k]^T (Y[k] - B_{k,m-1}*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_{k,m-1}*S[k])

       But the product B_{k,m-1}*S[k] is not readily computable (see e.g. Lu, Xuehua, "A study of the limited memory SR1
       method in practice", 1996).
     */
    PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_Y, &Y, NULL, NULL));
    PetscCall(LMBasisGetWorkVec(Y, &work));
    // work = B * s, using the approximation as it stands before this update
    PetscCall(MatMult(B, lmvm->Xprev, work));
    // work = y - B * s
    PetscCall(VecAYPX(work, -1.0, lmvm->Fprev));
    PetscCall(VecDot(lmvm->Xprev, work, &sktw));
    PetscCall(VecNorm(lmvm->Xprev, NORM_2, &snorm));
    PetscCall(VecNorm(work, NORM_2, &pnorm));
    PetscCall(LMBasisRestoreWorkVec(Y, &work));
    if (PetscAbsReal(PetscRealPart(sktw)) >= lmvm->eps * snorm * pnorm) {
      /* Update is good, accept it */
      PetscCall(MatUpdateKernel_LMVM(B, lmvm->Xprev, lmvm->Fprev));
      if (cache_SmH0YtF) {
        PetscInt oldest_new, next_new;

        PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_S_MINUS_H0Y, &SmH0Y, NULL, NULL));
        PetscCall(MatLMVMGetRange(B, &oldest_new, &next_new));
        // fill in the Fprev product for the newly added index range [next, next_new)
        PetscCall(LMBasisGEMVH(SmH0Y, next, next_new, 1.0, Fprev_old, 0.0, SmH0YtFprev_old));
        // recompute the cached product against the new F over the whole new range
        PetscCall(LMBasisGEMVH(SmH0Y, oldest_new, next_new, 1.0, F, 0.0, sr1->SmH0YtFprev));
        PetscCall(LMBasisSetCachedProduct(SmH0Y, F, sr1->SmH0YtFprev));
        // column = (product with F) - (product with Fprev), i.e. the product with Y = F - Fprev
        PetscCall(VecAXPBY(SmH0YtFprev_old, 1.0, -1.0, sr1->SmH0YtFprev));
        PetscCall(LMProductsRestoreNextColumn(SmH0YtY, &SmH0YtFprev_old));
      }
    } else {
      /* Update is bad, skip it */
      lmvm->nrejects++;
      if (cache_SmH0YtF) {
        // we still need to update the cached product
        PetscCall(LMBasisGEMVH(SmH0Y, oldest, next, 1.0, F, 0.0, sr1->SmH0YtFprev));
        PetscCall(LMBasisSetCachedProduct(SmH0Y, F, sr1->SmH0YtFprev));
      }
    }
    if (cache_SmH0YtF) PetscCall(LMBasisRestoreWorkVec(SmH0Y, &Fprev_old));
  }
  /* Save the solution and function to be used in the next update */
  PetscCall(VecCopy(X, lmvm->Xprev));
  PetscCall(VecCopy(F, lmvm->Fprev));
  lmvm->prev_set = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}

368: static PetscErrorCode MatReset_LMVMSR1(Mat B, MatLMVMResetMode mode)
369: {
370:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
371:   Mat_LSR1 *lsr1 = (Mat_LSR1 *)lmvm->ctx;

373:   PetscFunctionBegin;
374:   if (MatLMVMResetClearsBases(mode)) {
375:     for (PetscInt i = 0; i < SR1_BASIS_COUNT; i++) PetscCall(LMBasisDestroy(&lsr1->basis[i]));
376:     for (PetscInt i = 0; i < SR1_PRODUCTS_COUNT; i++) PetscCall(LMProductsDestroy(&lsr1->products[i]));
377:     PetscCall(VecDestroy(&lsr1->StFprev));
378:     PetscCall(VecDestroy(&lsr1->SmH0YtFprev));
379:   } else {
380:     for (PetscInt i = 0; i < SR1_BASIS_COUNT; i++) PetscCall(LMBasisReset(lsr1->basis[i]));
381:     for (PetscInt i = 0; i < SR1_PRODUCTS_COUNT; i++) PetscCall(LMProductsReset(lsr1->products[i]));
382:     if (lsr1->StFprev) PetscCall(VecZeroEntries(lsr1->StFprev));
383:     if (lsr1->SmH0YtFprev) PetscCall(VecZeroEntries(lsr1->SmH0YtFprev));
384:   }
385:   PetscFunctionReturn(PETSC_SUCCESS);
386: }

388: static PetscErrorCode MatDestroy_LMVMSR1(Mat B)
389: {
390:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

392:   PetscFunctionBegin;
393:   PetscCall(MatReset_LMVMSR1(B, MAT_LMVM_RESET_ALL));
394:   PetscCall(PetscFree(lmvm->ctx));
395:   PetscCall(MatDestroy_LMVM(B));
396:   PetscFunctionReturn(PETSC_SUCCESS);
397: }

399: static PetscErrorCode MatLMVMSetMultAlgorithm_SR1(Mat B)
400: {
401:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

403:   PetscFunctionBegin;
404:   switch (lmvm->mult_alg) {
405:   case MAT_LMVM_MULT_RECURSIVE:
406:     lmvm->ops->mult  = MatMult_LMVMSR1_Recursive;
407:     lmvm->ops->solve = MatSolve_LMVMSR1_Recursive;
408:     break;
409:   case MAT_LMVM_MULT_DENSE:
410:   case MAT_LMVM_MULT_COMPACT_DENSE:
411:     lmvm->ops->mult  = MatMult_LMVMSR1_CompactDense;
412:     lmvm->ops->solve = MatSolve_LMVMSR1_CompactDense;
413:     break;
414:   }
415:   lmvm->ops->multht  = lmvm->ops->mult;
416:   lmvm->ops->solveht = lmvm->ops->solve;
417:   PetscFunctionReturn(PETSC_SUCCESS);
418: }

420: PetscErrorCode MatCreate_LMVMSR1(Mat B)
421: {
422:   Mat_LMVM *lmvm;
423:   Mat_LSR1 *lsr1;

425:   PetscFunctionBegin;
426:   PetscCall(MatCreate_LMVM(B));
427:   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMSR1));
428:   PetscCall(MatSetOption(B, MAT_HERMITIAN, PETSC_TRUE));
429:   B->ops->destroy = MatDestroy_LMVMSR1;

431:   lmvm                          = (Mat_LMVM *)B->data;
432:   lmvm->ops->reset              = MatReset_LMVMSR1;
433:   lmvm->ops->update             = MatUpdate_LMVMSR1;
434:   lmvm->ops->setmultalgorithm   = MatLMVMSetMultAlgorithm_SR1;
435:   lmvm->cache_gradient_products = PETSC_TRUE;
436:   PetscCall(MatLMVMSetMultAlgorithm_SR1(B));
437:   PetscCall(PetscNew(&lsr1));
438:   lmvm->ctx = (void *)lsr1;
439:   PetscFunctionReturn(PETSC_SUCCESS);
440: }

442: /*@
443:   MatCreateLMVMSR1 - Creates a limited-memory Symmetric-Rank-1 approximation
444:   matrix used for a Jacobian. L-SR1 is symmetric by construction, but is not
445:   guaranteed to be positive-definite.

447:   To use the L-SR1 matrix with other vector types, the matrix must be
448:   created using `MatCreate()` and `MatSetType()`, followed by `MatLMVMAllocate()`.
449:   This ensures that the internal storage and work vectors are duplicated from the
450:   correct type of vector.

452:   Collective

454:   Input Parameters:
455: + comm - MPI communicator
456: . n    - number of local rows for storage vectors
457: - N    - global size of the storage vectors

459:   Output Parameter:
460: . B - the matrix

462:   Options Database Keys:
463: + -mat_lmvm_hist_size         - the number of history vectors to keep
464: . -mat_lmvm_mult_algorithm    - the algorithm to use for multiplication (recursive, dense, compact_dense)
465: . -mat_lmvm_cache_J0_products - whether products between the base Jacobian J0 and history vectors should be cached or recomputed
466: . -mat_lmvm_eps               - (developer) numerical zero tolerance for testing when an update should be skipped
467: - -mat_lmvm_debug             - (developer) perform internal debugging checks

469:   Level: intermediate

471:   Note:
472:   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`
473:   paradigm instead of this routine directly.

475: .seealso: [](ch_ksp), `MatCreate()`, `MATLMVM`, `MATLMVMSR1`, `MatCreateLMVMBFGS()`, `MatCreateLMVMDFP()`,
476:           `MatCreateLMVMBroyden()`, `MatCreateLMVMBadBroyden()`, `MatCreateLMVMSymBroyden()`
477: @*/
478: PetscErrorCode MatCreateLMVMSR1(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
479: {
480:   PetscFunctionBegin;
481:   PetscCall(KSPInitializePackage());
482:   PetscCall(MatCreate(comm, B));
483:   PetscCall(MatSetSizes(*B, n, n, N, N));
484:   PetscCall(MatSetType(*B, MATLMVMSR1));
485:   PetscCall(MatSetUp(*B));
486:   PetscFunctionReturn(PETSC_SUCCESS);
487: }