Actual source code: sr1.c
1: #include <../src/ksp/ksp/utils/lmvm/lmvm.h>
3: /*
4: Limited-memory Symmetric-Rank-1 method for approximating both
5: the forward product and inverse application of a Jacobian.
6: */
// Extra bases that the SR1 algorithms need in addition to those stored in Mat_LMVM.
// The entries form a primal/dual pair; LMVMModeMap() is used to pick the one for a given mode.
enum {
  SR1_BASIS_Y_MINUS_BKS = 0, // Y_k - B_k S_k, used by the recursive algorithm
  SR1_BASIS_S_MINUS_HKY,     // dual of the above: S_k - H_k Y_k
  SR1_BASIS_COUNT            // number of extra bases (size of Mat_LSR1::basis)
};
// Index into Mat_LSR1::basis; holds SR1_BASIS_* values (possibly remapped by LMVMModeMap())
typedef PetscInt SR1BasisType;
// Products needed by SR1 algorithms beyond those in Mat_LMVM.
// Entries come in primal/dual pairs: each odd value is the dual of the even value before it.
enum {
  SR1_PRODUCTS_YTS_MINUS_STB0S = 0, // stores and factors symm(triu((Y - B_0 S)^T S)) for compact algorithms
  SR1_PRODUCTS_STY_MINUS_YTH0Y = 1, // dual to the above, stores and factors symm(triu((S - H_0 Y)^T Y))
  SR1_PRODUCTS_YTS_MINUS_STBKS = 2, // diagonal (Y_k - B_k S_k)^T S_k values for recursive algorithms
  SR1_PRODUCTS_STY_MINUS_YTHKY = 3, // dual to the above, diagonal (S_k - H_k Y_k)^T Y_k
  SR1_PRODUCTS_COUNT                // number of extra product slots (size of Mat_LSR1::products)
};
// Index into Mat_LSR1::products; holds SR1_PRODUCTS_* values (possibly remapped by LMVMModeMap())
typedef PetscInt SR1ProductsType;
28: typedef struct {
29: LMBasis basis[SR1_BASIS_COUNT];
30: LMProducts products[SR1_PRODUCTS_COUNT];
31: Vec StFprev, SmH0YtFprev;
32: } Mat_LSR1;
34: /* The SR1 kernel can be written as
36: B_{k+1} = B_k + (y_k - B_k s_k) ((y_k - B_k s_k)^T s_k)^{-1} (y_k - B_k s_k)^T
38: this unrolls to a rank-m update
40: B_{k+1} = B_0 + \sum_{i = k-m+1}^k (y_i - B_i s_i) ((y_i - B_i s_i)^T s_i)^{-1} (y_i - B_i s_i)^T
42: This inner kernel assumes the (y_i - B_i s_i) vectors and the ((y_i - B_i s_i)^T s_i) products have been computed
43: */
45: static PetscErrorCode SR1Kernel_Recursive_Inner(Mat B, MatLMVMMode mode, PetscInt oldest, PetscInt next, Vec X, Vec BX)
46: {
47: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
48: Mat_LSR1 *lsr1 = (Mat_LSR1 *)lmvm->ctx;
49: SR1BasisType Y_minus_BkS_t = LMVMModeMap(SR1_BASIS_Y_MINUS_BKS, mode);
50: SR1ProductsType YtS_minus_StBkS_t = LMVMModeMap(SR1_PRODUCTS_STY_MINUS_YTHKY, mode);
51: LMBasis Y_minus_BkS = lsr1->basis[Y_minus_BkS_t];
52: LMProducts YtS_minus_StBkS = lsr1->products[YtS_minus_StBkS_t];
53: Vec YmBkStX;
55: PetscFunctionBegin;
56: PetscCall(MatLMVMGetWorkRow(B, &YmBkStX));
57: PetscCall(LMBasisGEMVH(Y_minus_BkS, oldest, next, 1.0, X, 0.0, YmBkStX));
58: PetscCall(LMProductsSolve(YtS_minus_StBkS, oldest, next, YmBkStX, YmBkStX, /* ^H */ PETSC_FALSE));
59: PetscCall(LMBasisGEMV(Y_minus_BkS, oldest, next, 1.0, YmBkStX, 1.0, BX));
60: PetscCall(MatLMVMRestoreWorkRow(B, &YmBkStX));
61: PetscFunctionReturn(PETSC_SUCCESS);
62: }
64: /* Recursively compute the (y_i - B_i s_i) vectors and ((y_i - B_i s_i)^T s_i) products */
66: static PetscErrorCode SR1RecursiveBasisUpdate(Mat B, MatLMVMMode mode)
67: {
68: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
69: Mat_LSR1 *lsr1 = (Mat_LSR1 *)lmvm->ctx;
70: MatLMVMBasisType B0S_t = LMVMModeMap(LMBASIS_B0S, mode);
71: MatLMVMBasisType S_t = LMVMModeMap(LMBASIS_S, mode);
72: MatLMVMBasisType Y_t = LMVMModeMap(LMBASIS_Y, mode);
73: SR1BasisType Y_minus_BkS_t = LMVMModeMap(SR1_BASIS_Y_MINUS_BKS, mode);
74: SR1ProductsType YtS_minus_StBkS_t = LMVMModeMap(SR1_PRODUCTS_STY_MINUS_YTHKY, mode);
75: LMBasis Y_minus_BkS;
76: LMProducts YtS_minus_StBkS;
77: PetscInt oldest, next;
78: PetscInt products_oldest;
79: LMBasis S, Y;
80: PetscInt start;
82: PetscFunctionBegin;
83: if (!lsr1->basis[Y_minus_BkS_t]) PetscCall(LMBasisCreate(mode == MATLMVM_MODE_PRIMAL ? lmvm->Fprev : lmvm->Xprev, lmvm->m, &lsr1->basis[Y_minus_BkS_t]));
84: Y_minus_BkS = lsr1->basis[Y_minus_BkS_t];
85: if (!lsr1->products[YtS_minus_StBkS_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_DIAGONAL, &lsr1->products[YtS_minus_StBkS_t]));
86: YtS_minus_StBkS = lsr1->products[YtS_minus_StBkS_t];
87: PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
88: PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
89: PetscCall(MatLMVMGetRange(B, &oldest, &next));
90: // invalidate computed values if J0 has changed
91: PetscCall(LMProductsPrepare(YtS_minus_StBkS, lmvm->J0, oldest, next));
92: products_oldest = PetscMax(0, YtS_minus_StBkS->k - lmvm->m);
93: if (oldest > products_oldest) {
94: // recursion is starting from a different starting index, it must be recomputed
95: YtS_minus_StBkS->k = oldest;
96: }
97: Y_minus_BkS->k = start = YtS_minus_StBkS->k;
98: // recompute each column in Y_minus_BkS in order
99: for (PetscInt j = start; j < next; j++) {
100: Vec s_j, B0s_j, p_j, y_j;
101: PetscScalar alpha, ymbksts;
103: PetscCall(LMBasisGetWorkVec(Y_minus_BkS, &p_j));
105: // p_j starts as B_0 * s_j
106: PetscCall(MatLMVMBasisGetVecRead(B, B0S_t, j, &B0s_j, &alpha));
107: PetscCall(VecAXPBY(p_j, alpha, 0.0, B0s_j));
108: PetscCall(MatLMVMBasisRestoreVecRead(B, B0S_t, j, &B0s_j, &alpha));
110: // Use the matmult kernel to compute p_j = B_j * p_j
111: PetscCall(LMBasisGetVecRead(S, j, &s_j));
112: // if j == oldest p_j is already correct
113: if (j > oldest) PetscCall(SR1Kernel_Recursive_Inner(B, mode, oldest, j, s_j, p_j));
114: PetscCall(LMBasisGetVecRead(Y, j, &y_j));
115: PetscCall(VecAYPX(p_j, -1.0, y_j));
116: PetscCall(VecDot(s_j, p_j, &ymbksts));
117: PetscCall(LMProductsInsertNextDiagonalValue(YtS_minus_StBkS, j, ymbksts));
118: PetscCall(LMBasisRestoreVecRead(S, j, &s_j));
119: PetscCall(LMBasisRestoreVecRead(Y, j, &y_j));
120: PetscCall(LMBasisSetNextVec(Y_minus_BkS, p_j));
121: PetscCall(LMBasisRestoreWorkVec(Y_minus_BkS, &p_j));
122: }
123: PetscFunctionReturn(PETSC_SUCCESS);
124: }
126: static PetscErrorCode SR1Kernel_Recursive(Mat B, MatLMVMMode mode, Vec X, Vec BX)
127: {
128: PetscInt oldest, next;
130: PetscFunctionBegin;
131: PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, BX));
132: PetscCall(MatLMVMGetRange(B, &oldest, &next));
133: if (next > oldest) {
134: PetscCall(SR1RecursiveBasisUpdate(B, mode));
135: PetscCall(SR1Kernel_Recursive_Inner(B, mode, oldest, next, X, BX));
136: }
137: PetscFunctionReturn(PETSC_SUCCESS);
138: }
/* The SR1 kernel can be written as (see Byrd, Nocedal & Schnabel 1994)

     B_{k+1} = B_0 + (Y - B_0 S) M^{-1} (Y - B_0 S)^T,

   where

     M = diag(S^T Y) + stril(S^T Y) + stril(S^T Y)^T - S^T B_0 S.

   M is symmetric indefinite (stril is the strictly lower triangular part).

   M can be computed by computing triu((Y - B_0 S)^T S) and filling in the lower triangle by symmetry.
*/
/* Build and factor the compact-dense SR1 middle matrix M = symm(triu((Y - B_0 S)^T S)).
   In dual mode the dual matrix symm(triu((S - H_0 Y)^T Y)) is used instead.

   The matrix covers the history window [oldest, next) and is stored in the
   SR1_PRODUCTS_YTS_MINUS_STB0S products slot, which is created on first use.
   It is rebuilt only when YtS_minus_StB0S->k < next, i.e. when it does not yet
   include the newest column.

   M is factored in place. The first attempt is a symmetric-indefinite (Bunch-Kaufman)
   factorization through MatCholeskyFactor(); LU is used if that returns PETSC_ERR_SUP. */
static PetscErrorCode SR1CompactProductsUpdate(Mat B, MatLMVMMode mode)
{
  Mat_LMVM        *lmvm              = (Mat_LMVM *)B->data;
  Mat_LSR1        *lsr1              = (Mat_LSR1 *)lmvm->ctx;
  MatLMVMBasisType S_t               = LMVMModeMap(LMBASIS_S, mode);
  MatLMVMBasisType YmB0S_t           = LMVMModeMap(LMBASIS_Y_MINUS_B0S, mode);
  SR1ProductsType  YtS_minus_StB0S_t = LMVMModeMap(SR1_PRODUCTS_YTS_MINUS_STB0S, mode);
  LMProducts       YtS_minus_StB0S;
  Mat              local;
  PetscInt         oldest, next, k;
  PetscBool        local_is_nonempty;

  PetscFunctionBegin;
  // lazily create dense (LMBLOCK_FULL) storage for M
  if (!lsr1->products[YtS_minus_StB0S_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_FULL, &lsr1->products[YtS_minus_StB0S_t]));
  YtS_minus_StB0S = lsr1->products[YtS_minus_StB0S_t];
  PetscCall(MatLMVMGetRange(B, &oldest, &next));
  // NOTE(review): presumably invalidates stale cached values (e.g. if J0 changed) -- confirm
  PetscCall(LMProductsPrepare(YtS_minus_StB0S, lmvm->J0, oldest, next));
  PetscCall(LMProductsGetLocalMatrix(YtS_minus_StB0S, &local, &k, &local_is_nonempty));
  if (YtS_minus_StB0S->k < next) {
    // copy to factor in place
    LMProducts YmB0StS;
    Mat        ymb0sts_local;

    PetscCall(PetscCitationsRegister(ByrdNocedalSchnabelCitation, &ByrdNocedalSchnabelCite));
    YtS_minus_StB0S->k = next;
    // upper triangle of (Y - B_0 S)^T S (or its dual), kept up to date by Mat_LMVM
    PetscCall(MatLMVMGetUpdatedProducts(B, YmB0S_t, S_t, LMBLOCK_UPPER_TRIANGLE, &YmB0StS));
    PetscCall(LMProductsGetLocalMatrix(YmB0StS, &ymb0sts_local, NULL, NULL));
    if (local_is_nonempty) {
      PetscErrorCode ierr;

      // local may still hold the previous factorization
      PetscCall(MatSetUnfactored(local));
      PetscCall(MatCopy(ymb0sts_local, local, SAME_NONZERO_PATTERN));
      // mirror the upper triangle into the lower triangle to form symm(triu(...))
      PetscCall(LMProductsMakeHermitian(local, oldest, next));
      // NOTE(review): presumably places ones on the diagonal of rows outside [oldest, next) so the factorization is nonsingular -- confirm
      PetscCall(LMProductsOnesOnUnusedDiagonal(local, oldest, next));
      PetscCall(MatSetOption(local, MAT_HERMITIAN, PETSC_TRUE));
      // Set not spd so that "Cholesky" factorization is actually the symmetric indefinite Bunch Kaufman factorization
      PetscCall(MatSetOption(local, MAT_SPD, PETSC_FALSE));

      // return errors instead of invoking the handler so an unsupported factorization can be detected below
      PetscCall(PetscPushErrorHandler(PetscReturnErrorHandler, NULL));
      ierr = MatCholeskyFactor(local, NULL, NULL);
      PetscCall(PetscPopErrorHandler());
      PetscCheck(ierr == PETSC_SUCCESS || ierr == PETSC_ERR_SUP, PETSC_COMM_SELF, ierr, "Error in Bunch-Kaufman factorization");
      // cusolver does not provide Bunch Kaufman, resort to LU if it is unavailable
      if (ierr == PETSC_ERR_SUP) PetscCall(MatLUFactor(local, NULL, NULL, NULL));
    }
    PetscCall(LMProductsRestoreLocalMatrix(YmB0StS, &ymb0sts_local, NULL));
  }
  // NOTE(review): restored with next rather than the k obtained above; presumably records coverage up to next -- confirm
  PetscCall(LMProductsRestoreLocalMatrix(YtS_minus_StB0S, &local, &next));
  PetscFunctionReturn(PETSC_SUCCESS);
}
203: static PetscErrorCode SR1Kernel_CompactDense(Mat B, MatLMVMMode mode, Vec X, Vec BX)
204: {
205: PetscInt oldest, next;
207: PetscFunctionBegin;
208: PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, BX));
209: PetscCall(MatLMVMGetRange(B, &oldest, &next));
210: if (next > oldest) {
211: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
212: Mat_LSR1 *lsr1 = (Mat_LSR1 *)lmvm->ctx;
213: MatLMVMBasisType Y_minus_B0S_t = LMVMModeMap(LMBASIS_Y_MINUS_B0S, mode);
214: SR1ProductsType YtS_minus_StB0S_t = LMVMModeMap(SR1_PRODUCTS_YTS_MINUS_STB0S, mode);
215: LMProducts YtS_minus_StB0S;
216: Vec YmB0StX, v;
218: PetscCall(SR1CompactProductsUpdate(B, mode));
219: YtS_minus_StB0S = lsr1->products[YtS_minus_StB0S_t];
220: PetscCall(MatLMVMGetWorkRow(B, &YmB0StX));
221: PetscCall(MatLMVMGetWorkRow(B, &v));
222: if (lmvm->do_not_cache_J0_products) {
223: /* the initial (Y - B_0 S)^T x inner product can be computed as Y^T x - S^T (B_0 x)
224: if we are not caching B_0 S products */
225: MatLMVMBasisType S_t = LMVMModeMap(LMBASIS_S, mode);
226: MatLMVMBasisType Y_t = LMVMModeMap(LMBASIS_Y, mode);
227: LMBasis S, Y;
229: PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
230: PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
231: PetscCall(LMBasisGEMVH(Y, oldest, next, 1.0, X, 0.0, YmB0StX));
232: PetscCall(LMBasisGEMVH(S, oldest, next, -1.0, BX, 1.0, YmB0StX));
233: } else PetscCall(MatLMVMBasisGEMVH(B, Y_minus_B0S_t, oldest, next, 1.0, X, 0.0, YmB0StX));
234: PetscCall(LMProductsSolve(YtS_minus_StB0S, oldest, next, YmB0StX, v, PETSC_FALSE));
235: PetscCall(MatLMVMBasisGEMV(B, Y_minus_B0S_t, oldest, next, 1.0, v, 1.0, BX));
236: PetscCall(MatLMVMRestoreWorkRow(B, &v));
237: PetscCall(MatLMVMRestoreWorkRow(B, &YmB0StX));
238: }
239: PetscFunctionReturn(PETSC_SUCCESS);
240: }
242: static PetscErrorCode MatMult_LMVMSR1_CompactDense(Mat B, Vec X, Vec BX)
243: {
244: PetscFunctionBegin;
245: PetscCall(SR1Kernel_CompactDense(B, MATLMVM_MODE_PRIMAL, X, BX));
246: PetscFunctionReturn(PETSC_SUCCESS);
247: }
249: static PetscErrorCode MatSolve_LMVMSR1_CompactDense(Mat B, Vec X, Vec BX)
250: {
251: PetscFunctionBegin;
252: PetscCall(SR1Kernel_CompactDense(B, MATLMVM_MODE_DUAL, X, BX));
253: PetscFunctionReturn(PETSC_SUCCESS);
254: }
256: static PetscErrorCode MatMult_LMVMSR1_Recursive(Mat B, Vec X, Vec Z)
257: {
258: PetscFunctionBegin;
259: PetscCall(SR1Kernel_Recursive(B, MATLMVM_MODE_PRIMAL, X, Z));
260: PetscFunctionReturn(PETSC_SUCCESS);
261: }
263: static PetscErrorCode MatSolve_LMVMSR1_Recursive(Mat B, Vec F, Vec dX)
264: {
265: PetscFunctionBegin;
266: PetscCall(SR1Kernel_Recursive(B, MATLMVM_MODE_DUAL, F, dX));
267: PetscFunctionReturn(PETSC_SUCCESS);
268: }
/* Record the pair (s_k, y_k) = (X - Xprev, F - Fprev) in the L-SR1 history.

   The pair is accepted only if it passes the SR1 safeguard
     |s_k^T (y_k - B_k s_k)| >= eps * ||s_k|| * ||y_k - B_k s_k||.
   Otherwise it is skipped and lmvm->nrejects is incremented.
   In both cases X and F are then saved as Xprev/Fprev for the next call.
   Nothing at all is done when the history size lmvm->m is zero.

   cache_SmH0YtF is true only when all of the following hold:
     - the algorithm is not recursive,
     - J0 products are cached, and
     - lmvm->cache_gradient_products is set.
   In that case the routine also keeps sr1->SmH0YtFprev = (S - H_0 Y)^H F. It uses that value to
   fill the new column (S - H_0 Y)^H y_k of the upper-triangular (S - H_0 Y)^H Y products,
   reusing the product with the previous F where possible. */
static PetscErrorCode MatUpdate_LMVMSR1(Mat B, Vec X, Vec F)
{
  Mat_LMVM *lmvm          = (Mat_LMVM *)B->data;
  Mat_LSR1 *sr1           = (Mat_LSR1 *)lmvm->ctx;
  PetscBool cache_SmH0YtF = (lmvm->mult_alg != MAT_LMVM_MULT_RECURSIVE && !lmvm->do_not_cache_J0_products) ? lmvm->cache_gradient_products : PETSC_FALSE;

  PetscFunctionBegin;
  if (!lmvm->m) PetscFunctionReturn(PETSC_SUCCESS);
  if (lmvm->prev_set) {
    PetscReal   snorm, pnorm;
    PetscScalar sktw;
    Vec         work;
    Vec         Fprev_old       = NULL;
    Vec         SmH0YtFprev_old = NULL;
    LMProducts  SmH0YtY         = NULL;
    PetscInt    oldest, next;
    LMBasis     SmH0Y = NULL;
    LMBasis     Y;

    // history window before this update
    PetscCall(MatLMVMGetRange(B, &oldest, &next));
    if (cache_SmH0YtF) {
      PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_S_MINUS_H0Y, &SmH0Y, NULL, NULL));
      if (!sr1->SmH0YtFprev) PetscCall(LMBasisCreateRow(SmH0Y, &sr1->SmH0YtFprev));
      // keep a copy of Fprev: lmvm->Fprev is overwritten with y_k below
      PetscCall(LMBasisGetWorkVec(SmH0Y, &Fprev_old));
      PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_S_MINUS_H0Y, LMBASIS_Y, LMBLOCK_UPPER_TRIANGLE, &SmH0YtY));
      // SmH0YtFprev_old will become the new column of the (S - H_0 Y)^H Y products
      PetscCall(LMProductsGetNextColumn(SmH0YtY, &SmH0YtFprev_old));
      PetscCall(VecCopy(lmvm->Fprev, Fprev_old));
      // NOTE(review): pointer equality presumably means the basis still holds our cached (S - H_0 Y)^H Fprev, so it can be reused -- confirm
      if (sr1->SmH0YtFprev == SmH0Y->cached_product) {
        PetscCall(VecCopy(sr1->SmH0YtFprev, SmH0YtFprev_old));
      } else {
        if (next > oldest) {
          // need to recalculate
          PetscCall(LMBasisGEMVH(SmH0Y, oldest, next, 1.0, Fprev_old, 0.0, SmH0YtFprev_old));
        } else {
          PetscCall(VecZeroEntries(SmH0YtFprev_old));
        }
      }
    }

    /* Compute the new (S = X - Xprev) and (Y = F - Fprev) vectors */
    PetscCall(VecAYPX(lmvm->Xprev, -1.0, X));
    PetscCall(VecAYPX(lmvm->Fprev, -1.0, F));

    /* See if the updates can be accepted
       NOTE: This tests abs(S[k]^T (Y[k] - B_k*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_k*S[k])

       Note that this test is flawed because this is a **limited memory** SR1 method: we are testing

       abs(S[k]^T (Y[k] - B_{k,m}*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_{k,m}*S[k])

       when the oldest pair of vectors in the definition of B_{k,m}, (s_{k-m}, y_{k-m}), will be dropped if we add a new
       pair. To really ensure that B_{k+1} = B_{k+1,m} is nonsingular, you need to test

       abs(S[k]^T (Y[k] - B_{k,m-1}*S[k])) >= eps * norm(S[k]) * norm(Y[k] - B_{k,m-1}*S[k])

       But the product B_{k,m-1}*S[k] is not readily computable (see e.g. Lu, Xuehua, "A study of the limited memory SR1
       method in practice", 1996).
    */
    PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_Y, &Y, NULL, NULL));
    PetscCall(LMBasisGetWorkVec(Y, &work));
    // work = y_k - B_k s_k
    PetscCall(MatMult(B, lmvm->Xprev, work));
    PetscCall(VecAYPX(work, -1.0, lmvm->Fprev));
    PetscCall(VecDot(lmvm->Xprev, work, &sktw));
    PetscCall(VecNorm(lmvm->Xprev, NORM_2, &snorm));
    PetscCall(VecNorm(work, NORM_2, &pnorm));
    PetscCall(LMBasisRestoreWorkVec(Y, &work));
    if (PetscAbsReal(PetscRealPart(sktw)) >= lmvm->eps * snorm * pnorm) {
      /* Update is good, accept it */
      PetscCall(MatUpdateKernel_LMVM(B, lmvm->Xprev, lmvm->Fprev));
      if (cache_SmH0YtF) {
        PetscInt oldest_new, next_new;

        // the update changed the history, so refetch the basis and the window
        PetscCall(MatLMVMGetUpdatedBasis(B, LMBASIS_S_MINUS_H0Y, &SmH0Y, NULL, NULL));
        PetscCall(MatLMVMGetRange(B, &oldest_new, &next_new));
        // fill the entries of (S - H_0 Y)^H Fprev for the newly added column(s)
        PetscCall(LMBasisGEMVH(SmH0Y, next, next_new, 1.0, Fprev_old, 0.0, SmH0YtFprev_old));
        // cache (S - H_0 Y)^H F for the next update
        PetscCall(LMBasisGEMVH(SmH0Y, oldest_new, next_new, 1.0, F, 0.0, sr1->SmH0YtFprev));
        PetscCall(LMBasisSetCachedProduct(SmH0Y, F, sr1->SmH0YtFprev));
        // SmH0YtFprev_old = (S - H_0 Y)^H F - (S - H_0 Y)^H Fprev = (S - H_0 Y)^H y_k
        PetscCall(VecAXPBY(SmH0YtFprev_old, 1.0, -1.0, sr1->SmH0YtFprev));
        // store it as the new column of the (S - H_0 Y)^H Y products
        PetscCall(LMProductsRestoreNextColumn(SmH0YtY, &SmH0YtFprev_old));
      }
    } else {
      /* Update is bad, skip it */
      lmvm->nrejects++;
      if (cache_SmH0YtF) {
        // we still need to update the cached product
        // NOTE(review): SmH0YtFprev_old from LMProductsGetNextColumn() is not restored on this path -- confirm the LMProducts API allows that
        PetscCall(LMBasisGEMVH(SmH0Y, oldest, next, 1.0, F, 0.0, sr1->SmH0YtFprev));
        PetscCall(LMBasisSetCachedProduct(SmH0Y, F, sr1->SmH0YtFprev));
      }
    }
    if (cache_SmH0YtF) PetscCall(LMBasisRestoreWorkVec(SmH0Y, &Fprev_old));
  }
  /* Save the solution and function to be used in the next update */
  PetscCall(VecCopy(X, lmvm->Xprev));
  PetscCall(VecCopy(F, lmvm->Fprev));
  lmvm->prev_set = PETSC_TRUE;
  PetscFunctionReturn(PETSC_SUCCESS);
}
368: static PetscErrorCode MatReset_LMVMSR1(Mat B, MatLMVMResetMode mode)
369: {
370: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
371: Mat_LSR1 *lsr1 = (Mat_LSR1 *)lmvm->ctx;
373: PetscFunctionBegin;
374: if (MatLMVMResetClearsBases(mode)) {
375: for (PetscInt i = 0; i < SR1_BASIS_COUNT; i++) PetscCall(LMBasisDestroy(&lsr1->basis[i]));
376: for (PetscInt i = 0; i < SR1_PRODUCTS_COUNT; i++) PetscCall(LMProductsDestroy(&lsr1->products[i]));
377: PetscCall(VecDestroy(&lsr1->StFprev));
378: PetscCall(VecDestroy(&lsr1->SmH0YtFprev));
379: } else {
380: for (PetscInt i = 0; i < SR1_BASIS_COUNT; i++) PetscCall(LMBasisReset(lsr1->basis[i]));
381: for (PetscInt i = 0; i < SR1_PRODUCTS_COUNT; i++) PetscCall(LMProductsReset(lsr1->products[i]));
382: if (lsr1->StFprev) PetscCall(VecZeroEntries(lsr1->StFprev));
383: if (lsr1->SmH0YtFprev) PetscCall(VecZeroEntries(lsr1->SmH0YtFprev));
384: }
385: PetscFunctionReturn(PETSC_SUCCESS);
386: }
388: static PetscErrorCode MatDestroy_LMVMSR1(Mat B)
389: {
390: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
392: PetscFunctionBegin;
393: PetscCall(MatReset_LMVMSR1(B, MAT_LMVM_RESET_ALL));
394: PetscCall(PetscFree(lmvm->ctx));
395: PetscCall(MatDestroy_LMVM(B));
396: PetscFunctionReturn(PETSC_SUCCESS);
397: }
399: static PetscErrorCode MatLMVMSetMultAlgorithm_SR1(Mat B)
400: {
401: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
403: PetscFunctionBegin;
404: switch (lmvm->mult_alg) {
405: case MAT_LMVM_MULT_RECURSIVE:
406: lmvm->ops->mult = MatMult_LMVMSR1_Recursive;
407: lmvm->ops->solve = MatSolve_LMVMSR1_Recursive;
408: break;
409: case MAT_LMVM_MULT_DENSE:
410: case MAT_LMVM_MULT_COMPACT_DENSE:
411: lmvm->ops->mult = MatMult_LMVMSR1_CompactDense;
412: lmvm->ops->solve = MatSolve_LMVMSR1_CompactDense;
413: break;
414: }
415: lmvm->ops->multht = lmvm->ops->mult;
416: lmvm->ops->solveht = lmvm->ops->solve;
417: PetscFunctionReturn(PETSC_SUCCESS);
418: }
420: PetscErrorCode MatCreate_LMVMSR1(Mat B)
421: {
422: Mat_LMVM *lmvm;
423: Mat_LSR1 *lsr1;
425: PetscFunctionBegin;
426: PetscCall(MatCreate_LMVM(B));
427: PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMSR1));
428: PetscCall(MatSetOption(B, MAT_HERMITIAN, PETSC_TRUE));
429: B->ops->destroy = MatDestroy_LMVMSR1;
431: lmvm = (Mat_LMVM *)B->data;
432: lmvm->ops->reset = MatReset_LMVMSR1;
433: lmvm->ops->update = MatUpdate_LMVMSR1;
434: lmvm->ops->setmultalgorithm = MatLMVMSetMultAlgorithm_SR1;
435: lmvm->cache_gradient_products = PETSC_TRUE;
436: PetscCall(MatLMVMSetMultAlgorithm_SR1(B));
437: PetscCall(PetscNew(&lsr1));
438: lmvm->ctx = (void *)lsr1;
439: PetscFunctionReturn(PETSC_SUCCESS);
440: }
442: /*@
443: MatCreateLMVMSR1 - Creates a limited-memory Symmetric-Rank-1 approximation
444: matrix used for a Jacobian. L-SR1 is symmetric by construction, but is not
445: guaranteed to be positive-definite.
447: To use the L-SR1 matrix with other vector types, the matrix must be
448: created using `MatCreate()` and `MatSetType()`, followed by `MatLMVMAllocate()`.
449: This ensures that the internal storage and work vectors are duplicated from the
450: correct type of vector.
452: Collective
454: Input Parameters:
455: + comm - MPI communicator
456: . n - number of local rows for storage vectors
457: - N - global size of the storage vectors
459: Output Parameter:
460: . B - the matrix
462: Options Database Keys:
463: + -mat_lmvm_hist_size - the number of history vectors to keep
464: . -mat_lmvm_mult_algorithm - the algorithm to use for multiplication (recursive, dense, compact_dense)
465: . -mat_lmvm_cache_J0_products - whether products between the base Jacobian J0 and history vectors should be cached or recomputed
466: . -mat_lmvm_eps - (developer) numerical zero tolerance for testing when an update should be skipped
467: - -mat_lmvm_debug - (developer) perform internal debugging checks
469: Level: intermediate
471: Note:
472: It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`
473: paradigm instead of this routine directly.
475: .seealso: [](ch_ksp), `MatCreate()`, `MATLMVM`, `MATLMVMSR1`, `MatCreateLMVMBFGS()`, `MatCreateLMVMDFP()`,
476: `MatCreateLMVMBroyden()`, `MatCreateLMVMBadBroyden()`, `MatCreateLMVMSymBroyden()`
477: @*/
478: PetscErrorCode MatCreateLMVMSR1(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
479: {
480: PetscFunctionBegin;
481: PetscCall(KSPInitializePackage());
482: PetscCall(MatCreate(comm, B));
483: PetscCall(MatSetSizes(*B, n, n, N, N));
484: PetscCall(MatSetType(*B, MATLMVMSR1));
485: PetscCall(MatSetUp(*B));
486: PetscFunctionReturn(PETSC_SUCCESS);
487: }