Actual source code: bfgs.c

  1: #include <../src/ksp/ksp/utils/lmvm/symbrdn/symbrdn.h>
  2: #include <petsc/private/vecimpl.h>
  3: #include <petscdevice.h>

  5: /* The BFGS update can be written

  7:    B_{k+1} = B_k + y_k (y_k^T s_k)^{-1} y_k^T - B_k s_k (s_k^T B_k s_k)^{-1} s_k^T B_k + y_k (y_k^T s_k)^{-1} y_k^T

  9:    Which can be unrolled as a parallel sum

 11:    B_{k+1} = B_0 + \sum_i B_i y_i (y_i^T s_i)^{-1} y_i^T - s_i (s_i^T B_i s_i)^{-1} s_i^T B_i

 13:    Once the (B_i y_i) vectors, (y_i^T s_i), and (s_i^T B_i s_i) products have been computed
 14:  */
 15: static PetscErrorCode BFGSKernel_Recursive_Inner(Mat B, MatLMVMMode mode, PetscInt oldest, PetscInt next, Vec X, Vec B0X)
 16: {
 17:   Mat_LMVM        *lmvm = (Mat_LMVM *)B->data;
 18:   Mat_SymBrdn     *lsb  = (Mat_SymBrdn *)lmvm->ctx;
 19:   MatLMVMBasisType Y_t  = LMVMModeMap(LMBASIS_Y, mode);
 20:   LMBasis          BkS  = lsb->basis[LMVMModeMap(SYMBROYDEN_BASIS_BKS, mode)];
 21:   LMProducts       YtS;
 22:   LMProducts       StBkS = lsb->products[LMVMModeMap(SYMBROYDEN_PRODUCTS_STBKS, mode)];
 23:   LMBasis          Y;
 24:   Vec              StBkX, YtX;

 26:   PetscFunctionBegin;
 27:   PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
 28:   PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_Y, LMBASIS_S, LMBLOCK_DIAGONAL, &YtS));
 29:   PetscCall(MatLMVMGetWorkRow(B, &StBkX));
 30:   PetscCall(MatLMVMGetWorkRow(B, &YtX));
 31:   PetscCall(LMBasisGEMVH(BkS, oldest, next, 1.0, X, 0.0, StBkX));
 32:   PetscCall(LMProductsSolve(StBkS, oldest, next, StBkX, StBkX, /* ^H */ PETSC_FALSE));
 33:   PetscCall(LMBasisGEMV(BkS, oldest, next, -1.0, StBkX, 1.0, B0X));
 34:   PetscCall(LMBasisGEMVH(Y, oldest, next, 1.0, X, 0.0, YtX));
 35:   PetscCall(LMProductsSolve(YtS, oldest, next, YtX, YtX, /* ^H */ PETSC_FALSE));
 36:   PetscCall(LMBasisGEMV(Y, oldest, next, 1.0, YtX, 1.0, B0X));
 37:   PetscCall(MatLMVMRestoreWorkRow(B, &YtX));
 38:   PetscCall(MatLMVMRestoreWorkRow(B, &StBkX));
 39:   PetscFunctionReturn(PETSC_SUCCESS);
 40: }

 42: /*
 43:    The B_i s_i vectors and (s_i^T B_i s_i) products are computed recursively
 44:  */
 45: static PetscErrorCode BFGSRecursiveBasisUpdate(Mat B, MatLMVMMode mode)
 46: {
 47:   Mat_LMVM              *lmvm    = (Mat_LMVM *)B->data;
 48:   Mat_SymBrdn           *lsb     = (Mat_SymBrdn *)lmvm->ctx;
 49:   MatLMVMBasisType       S_t     = LMVMModeMap(LMBASIS_S, mode);
 50:   MatLMVMBasisType       B0S_t   = LMVMModeMap(LMBASIS_B0S, mode);
 51:   SymBroydenProductsType StBkS_t = LMVMModeMap(SYMBROYDEN_PRODUCTS_STBKS, mode);
 52:   SymBroydenBasisType    BkS_t   = LMVMModeMap(SYMBROYDEN_BASIS_BKS, mode);
 53:   LMBasis                BkS;
 54:   LMProducts             StBkS, YtS;
 55:   PetscInt               oldest, start, next;
 56:   PetscInt               products_oldest;
 57:   LMBasis                S;

 59:   PetscFunctionBegin;
 60:   PetscCall(MatLMVMGetRange(B, &oldest, &next));
 61:   if (!lsb->basis[BkS_t]) PetscCall(LMBasisCreate(MatLMVMBasisSizeOf(B0S_t) == LMBASIS_S ? lmvm->Xprev : lmvm->Fprev, lmvm->m, &lsb->basis[BkS_t]));
 62:   BkS = lsb->basis[BkS_t];
 63:   if (!lsb->products[StBkS_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_DIAGONAL, &lsb->products[StBkS_t]));
 64:   StBkS = lsb->products[StBkS_t];
 65:   PetscCall(LMProductsPrepare(StBkS, lmvm->J0, oldest, next));
 66:   products_oldest = PetscMax(0, StBkS->k - lmvm->m);
 67:   if (oldest > products_oldest) {
 68:     // recursion is starting from a different starting index, it must be recomputed
 69:     StBkS->k = oldest;
 70:   }
 71:   BkS->k = start = StBkS->k;
 72:   if (start == next) PetscFunctionReturn(PETSC_SUCCESS);

 74:   PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
 75:   // make sure YtS is updated before entering the loop
 76:   PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_Y, LMBASIS_S, LMBLOCK_DIAGONAL, &YtS));
 77:   for (PetscInt j = start; j < next; j++) {
 78:     Vec         p_j, s_j, B0s_j;
 79:     PetscScalar alpha, sjtbjsj;

 81:     PetscCall(LMBasisGetWorkVec(BkS, &p_j));
 82:     // p_j starts as B_0 * s_j
 83:     PetscCall(MatLMVMBasisGetVecRead(B, B0S_t, j, &B0s_j, &alpha));
 84:     PetscCall(VecAXPBY(p_j, alpha, 0.0, B0s_j));
 85:     PetscCall(MatLMVMBasisRestoreVecRead(B, B0S_t, j, &B0s_j, &alpha));

 87:     // Use the matmult kernel to compute p_j = B_j * p_j
 88:     PetscCall(LMBasisGetVecRead(S, j, &s_j));
 89:     if (j > oldest) PetscCall(BFGSKernel_Recursive_Inner(B, mode, oldest, j, s_j, p_j));
 90:     PetscCall(VecDot(p_j, s_j, &sjtbjsj));
 91:     PetscCall(LMBasisRestoreVecRead(S, j, &s_j));
 92:     PetscCall(LMProductsInsertNextDiagonalValue(StBkS, j, sjtbjsj));
 93:     PetscCall(LMBasisSetNextVec(BkS, p_j));
 94:     PetscCall(LMBasisRestoreWorkVec(BkS, &p_j));
 95:   }
 96:   PetscFunctionReturn(PETSC_SUCCESS);
 97: }

 99: PETSC_INTERN PetscErrorCode BFGSKernel_Recursive(Mat B, MatLMVMMode mode, Vec X, Vec Y)
100: {
101:   PetscInt oldest, next;

103:   PetscFunctionBegin;
104:   PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, Y));
105:   PetscCall(MatLMVMGetRange(B, &oldest, &next));
106:   if (next > oldest) {
107:     PetscCall(BFGSRecursiveBasisUpdate(B, mode));
108:     PetscCall(BFGSKernel_Recursive_Inner(B, mode, oldest, next, X, Y));
109:   }
110:   PetscFunctionReturn(PETSC_SUCCESS);
111: }

113: static PetscErrorCode BFGSCompactDenseProductsUpdate(Mat B, MatLMVMMode mode)
114: {
115:   Mat_LMVM              *lmvm = (Mat_LMVM *)B->data;
116:   Mat_SymBrdn           *lsb  = (Mat_SymBrdn *)lmvm->ctx;
117:   PetscInt               oldest, next, k;
118:   MatLMVMBasisType       S_t   = LMVMModeMap(LMBASIS_S, mode);
119:   MatLMVMBasisType       B0S_t = LMVMModeMap(LMBASIS_B0S, mode);
120:   MatLMVMBasisType       Y_t   = LMVMModeMap(LMBASIS_Y, mode);
121:   SymBroydenProductsType M00_t = LMVMModeMap(SYMBROYDEN_PRODUCTS_M00, mode);
122:   LMProducts             M00, StB0S, YtS, D;
123:   Mat                    YtS_local, StB0S_local, M00_local;
124:   Vec                    D_local;
125:   PetscBool              local_is_nonempty;

127:   PetscFunctionBegin;
128:   PetscCall(MatLMVMGetRange(B, &oldest, &next));
129:   if (lsb->products[M00_t] && lsb->products[M00_t]->block_type != LMBLOCK_FULL) PetscCall(LMProductsDestroy(&lsb->products[M00_t]));
130:   if (!lsb->products[M00_t]) PetscCall(MatLMVMCreateProducts(B, LMBLOCK_FULL, &lsb->products[M00_t]));
131:   M00 = lsb->products[M00_t];
132:   PetscCall(LMProductsPrepare(M00, lmvm->J0, oldest, next));
133:   PetscCall(LMProductsGetLocalMatrix(M00, &M00_local, &k, &local_is_nonempty));
134:   if (k < next) {
135:     PetscCall(MatLMVMGetUpdatedProducts(B, Y_t, S_t, LMBLOCK_STRICT_UPPER_TRIANGLE, &YtS));
136:     PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_Y, LMBASIS_S, LMBLOCK_DIAGONAL, &D));
137:     PetscCall(MatLMVMGetUpdatedProducts(B, S_t, B0S_t, LMBLOCK_UPPER_TRIANGLE, &StB0S));

139:     PetscCall(LMProductsGetLocalMatrix(StB0S, &StB0S_local, NULL, NULL));
140:     PetscCall(LMProductsGetLocalMatrix(YtS, &YtS_local, NULL, NULL));
141:     PetscCall(LMProductsGetLocalDiagonal(D, &D_local));
142:     if (local_is_nonempty) {
143:       Vec invD;
144:       Mat stril_StY;

146:       PetscCall(MatSetUnfactored(M00_local));
147:       PetscCall(MatCopy(StB0S_local, M00_local, SAME_NONZERO_PATTERN));
148:       PetscCall(VecDuplicate(D_local, &invD));
149:       PetscCall(VecCopy(D_local, invD));
150:       PetscCall(VecReciprocal(invD));
151:       PetscCall(MatTranspose(YtS_local, MAT_INITIAL_MATRIX, &stril_StY));
152:       if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(stril_StY));

154:       PetscCall(MatDiagonalScale(stril_StY, NULL, invD));
155:       PetscCall(MatMatMult(stril_StY, YtS_local, MAT_REUSE_MATRIX, PETSC_DETERMINE, &M00_local));
156:       PetscCall(MatAXPY(M00_local, 1.0, StB0S_local, UNKNOWN_NONZERO_PATTERN));
157:       PetscCall(LMProductsMakeHermitian(M00_local, oldest, next));
158:       PetscCall(LMProductsOnesOnUnusedDiagonal(M00_local, oldest, next));
159:       PetscCall(MatSetOption(M00_local, MAT_HERMITIAN, PETSC_TRUE));
160:       PetscCall(MatSetOption(M00_local, MAT_SPD, PETSC_TRUE));
161:       PetscCall(MatCholeskyFactor(M00_local, NULL, NULL));
162:       PetscCall(MatDestroy(&stril_StY));
163:       PetscCall(VecDestroy(&invD));
164:     }
165:     PetscCall(LMProductsRestoreLocalDiagonal(D, &D_local));
166:     PetscCall(LMProductsRestoreLocalMatrix(YtS, &YtS_local, NULL));
167:     PetscCall(LMProductsRestoreLocalMatrix(StB0S, &StB0S_local, NULL));
168:   }
169:   PetscCall(LMProductsRestoreLocalMatrix(M00, &M00_local, &next));
170:   PetscFunctionReturn(PETSC_SUCCESS);
171: }

173: PETSC_INTERN PetscErrorCode BFGSKernel_CompactDense(Mat B, MatLMVMMode mode, Vec X, Vec BX)
174: {
175:   PetscInt oldest, next;

177:   PetscFunctionBegin;
178:   PetscCall(MatLMVMApplyJ0Mode(mode)(B, X, BX));
179:   PetscCall(MatLMVMGetRange(B, &oldest, &next));
180:   if (next > oldest) {
181:     Mat_LMVM              *lmvm  = (Mat_LMVM *)B->data;
182:     Mat_SymBrdn           *bfgs  = (Mat_SymBrdn *)lmvm->ctx;
183:     MatLMVMBasisType       S_t   = LMVMModeMap(LMBASIS_S, mode);
184:     MatLMVMBasisType       Y_t   = LMVMModeMap(LMBASIS_Y, mode);
185:     MatLMVMBasisType       B0S_t = LMVMModeMap(LMBASIS_B0S, mode);
186:     SymBroydenProductsType M00_t = LMVMModeMap(SYMBROYDEN_PRODUCTS_M00, mode);
187:     LMBasis                S, Y;
188:     PetscBool              use_B0S;
189:     Vec                    YtX, StB0X, u, v;
190:     LMProducts             M00, YtS, D;

192:     PetscCall(BFGSCompactDenseProductsUpdate(B, mode));
193:     PetscCall(MatLMVMGetUpdatedBasis(B, S_t, &S, NULL, NULL));
194:     PetscCall(MatLMVMGetUpdatedBasis(B, Y_t, &Y, NULL, NULL));
195:     PetscCall(MatLMVMGetUpdatedProducts(B, Y_t, S_t, LMBLOCK_STRICT_UPPER_TRIANGLE, &YtS));
196:     PetscCall(MatLMVMGetUpdatedProducts(B, LMBASIS_Y, LMBASIS_S, LMBLOCK_DIAGONAL, &D));
197:     M00 = bfgs->products[M00_t];

199:     PetscCall(MatLMVMGetWorkRow(B, &YtX));
200:     PetscCall(MatLMVMGetWorkRow(B, &StB0X));
201:     PetscCall(MatLMVMGetWorkRow(B, &u));
202:     PetscCall(MatLMVMGetWorkRow(B, &v));

204:     PetscCall(LMBasisGEMVH(Y, oldest, next, 1.0, X, 0.0, YtX));
205:     PetscCall(SymBroydenCompactDenseKernelUseB0S(B, mode, X, &use_B0S));
206:     if (use_B0S) PetscCall(MatLMVMBasisGEMVH(B, B0S_t, oldest, next, 1.0, X, 0.0, StB0X));
207:     else PetscCall(LMBasisGEMVH(S, oldest, next, 1.0, BX, 0.0, StB0X));

209:     PetscCall(LMProductsSolve(D, oldest, next, YtX, YtX, /* ^H */ PETSC_FALSE));
210:     PetscCall(LMProductsMult(YtS, oldest, next, 1.0, YtX, 1.0, StB0X, /* ^H */ PETSC_TRUE));
211:     PetscCall(LMProductsSolve(M00, oldest, next, StB0X, u, PETSC_FALSE));
212:     PetscCall(VecScale(u, -1.0));
213:     PetscCall(LMProductsMult(YtS, oldest, next, 1.0, u, 0.0, v, /* ^H */ PETSC_FALSE));
214:     PetscCall(LMProductsSolve(D, oldest, next, v, v, PETSC_FALSE));
215:     PetscCall(VecAXPY(v, 1.0, YtX));

217:     PetscCall(LMBasisGEMV(Y, oldest, next, 1.0, v, 1.0, BX));
218:     PetscCall(MatLMVMBasisGEMV(B, B0S_t, oldest, next, 1.0, u, 1.0, BX));

220:     PetscCall(MatLMVMRestoreWorkRow(B, &v));
221:     PetscCall(MatLMVMRestoreWorkRow(B, &u));
222:     PetscCall(MatLMVMRestoreWorkRow(B, &StB0X));
223:     PetscCall(MatLMVMRestoreWorkRow(B, &YtX));
224:   }
225:   PetscFunctionReturn(PETSC_SUCCESS);
226: }

228: static PetscErrorCode MatMult_LMVMBFGS_Recursive(Mat B, Vec X, Vec Y)
229: {
230:   PetscFunctionBegin;
231:   PetscCall(BFGSKernel_Recursive(B, MATLMVM_MODE_PRIMAL, X, Y));
232:   PetscFunctionReturn(PETSC_SUCCESS);
233: }

235: static PetscErrorCode MatMult_LMVMBFGS_CompactDense(Mat B, Vec X, Vec Y)
236: {
237:   PetscFunctionBegin;
238:   PetscCall(BFGSKernel_CompactDense(B, MATLMVM_MODE_PRIMAL, X, Y));
239:   PetscFunctionReturn(PETSC_SUCCESS);
240: }

242: static PetscErrorCode MatSolve_LMVMBFGS_Recursive(Mat B, Vec X, Vec HX)
243: {
244:   PetscFunctionBegin;
245:   PetscCall(DFPKernel_Recursive(B, MATLMVM_MODE_DUAL, X, HX));
246:   PetscFunctionReturn(PETSC_SUCCESS);
247: }

249: static PetscErrorCode MatSolve_LMVMBFGS_CompactDense(Mat B, Vec X, Vec HX)
250: {
251:   PetscFunctionBegin;
252:   PetscCall(DFPKernel_CompactDense(B, MATLMVM_MODE_DUAL, X, HX));
253:   PetscFunctionReturn(PETSC_SUCCESS);
254: }

256: static PetscErrorCode MatSolve_LMVMBFGS_Dense(Mat B, Vec X, Vec HX)
257: {
258:   PetscFunctionBegin;
259:   PetscCall(DFPKernel_Dense(B, MATLMVM_MODE_DUAL, X, HX));
260:   PetscFunctionReturn(PETSC_SUCCESS);
261: }

263: static PetscErrorCode MatSetFromOptions_LMVMBFGS(Mat B, PetscOptionItems PetscOptionsObject)
264: {
265:   Mat_LMVM    *lmvm  = (Mat_LMVM *)B->data;
266:   Mat_SymBrdn *lbfgs = (Mat_SymBrdn *)lmvm->ctx;

268:   PetscFunctionBegin;
269:   PetscCall(MatSetFromOptions_LMVM(B, PetscOptionsObject));
270:   PetscOptionsHeadBegin(PetscOptionsObject, "L-BFGS method for approximating SPD Jacobian actions (MATLMVMBFGS)");
271:   PetscCall(SymBroydenRescaleSetFromOptions(B, lbfgs->rescale, PetscOptionsObject));
272:   PetscOptionsHeadEnd();
273:   PetscFunctionReturn(PETSC_SUCCESS);
274: }

276: static PetscErrorCode MatLMVMSetMultAlgorithm_BFGS(Mat B)
277: {
278:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;

280:   PetscFunctionBegin;
281:   switch (lmvm->mult_alg) {
282:   case MAT_LMVM_MULT_RECURSIVE:
283:     lmvm->ops->mult  = MatMult_LMVMBFGS_Recursive;
284:     lmvm->ops->solve = MatSolve_LMVMBFGS_Recursive;
285:     break;
286:   case MAT_LMVM_MULT_DENSE:
287:     lmvm->ops->mult  = MatMult_LMVMBFGS_CompactDense;
288:     lmvm->ops->solve = MatSolve_LMVMBFGS_Dense;
289:     break;
290:   case MAT_LMVM_MULT_COMPACT_DENSE:
291:     lmvm->ops->mult  = MatMult_LMVMBFGS_CompactDense;
292:     lmvm->ops->solve = MatSolve_LMVMBFGS_CompactDense;
293:     break;
294:   }
295:   lmvm->ops->multht  = lmvm->ops->mult;
296:   lmvm->ops->solveht = lmvm->ops->solve;
297:   PetscFunctionReturn(PETSC_SUCCESS);
298: }

300: PetscErrorCode MatCreate_LMVMBFGS(Mat B)
301: {
302:   Mat_LMVM    *lmvm;
303:   Mat_SymBrdn *lbfgs;

305:   PetscFunctionBegin;
306:   PetscCall(MatCreate_LMVMSymBrdn(B));
307:   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATLMVMBFGS));
308:   B->ops->setfromoptions = MatSetFromOptions_LMVMBFGS;

310:   lmvm                        = (Mat_LMVM *)B->data;
311:   lmvm->ops->setmultalgorithm = MatLMVMSetMultAlgorithm_BFGS;
312:   PetscCall(MatLMVMSetMultAlgorithm_BFGS(B));

314:   lbfgs = (Mat_SymBrdn *)lmvm->ctx;

316:   lbfgs->phi_scalar = 0.0;
317:   lbfgs->psi_scalar = 1.0;
318:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatLMVMSymBroydenSetPhi_C", NULL));
319:   PetscFunctionReturn(PETSC_SUCCESS);
320: }

322: /*@
323:   MatCreateLMVMBFGS - Creates a limited-memory Broyden-Fletcher-Goldfarb-Shano (BFGS)
324:   matrix used for approximating Jacobians. L-BFGS is symmetric positive-definite by
325:   construction, and is commonly used to approximate Hessians in optimization
326:   problems.

328:   To use the L-BFGS matrix with other vector types, the matrix must be
329:   created using `MatCreate()` and `MatSetType()`, followed by `MatLMVMAllocate()`.
330:   This ensures that the internal storage and work vectors are duplicated from the
331:   correct type of vector.

333:   Collective

335:   Input Parameters:
336: + comm - MPI communicator
337: . n    - number of local rows for storage vectors
338: - N    - global size of the storage vectors

340:   Output Parameter:
341: . B - the matrix

343:   Options Database Keys:
344: + -mat_lmvm_scale_type - (developer) type of scaling applied to J0 (none, scalar, diagonal)
345: . -mat_lmvm_theta      - (developer) convex ratio between BFGS and DFP components of the diagonal J0 scaling
346: . -mat_lmvm_rho        - (developer) update limiter for the J0 scaling
347: . -mat_lmvm_alpha      - (developer) coefficient factor for the quadratic subproblem in J0 scaling
348: . -mat_lmvm_beta       - (developer) exponential factor for the diagonal J0 scaling
349: - -mat_lmvm_sigma_hist - (developer) number of past updates to use in J0 scaling

351:   Level: intermediate

353:   Note:
354:   It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`
355:   paradigm instead of this routine directly.

357: .seealso: [](ch_ksp), `MatCreate()`, `MATLMVM`, `MATLMVMBFGS`, `MatCreateLMVMDFP()`, `MatCreateLMVMSR1()`,
358:           `MatCreateLMVMBroyden()`, `MatCreateLMVMBadBroyden()`, `MatCreateLMVMSymBroyden()`
359: @*/
360: PetscErrorCode MatCreateLMVMBFGS(MPI_Comm comm, PetscInt n, PetscInt N, Mat *B)
361: {
362:   PetscFunctionBegin;
363:   PetscCall(KSPInitializePackage());
364:   PetscCall(MatCreate(comm, B));
365:   PetscCall(MatSetSizes(*B, n, n, N, N));
366:   PetscCall(MatSetType(*B, MATLMVMBFGS));
367:   PetscCall(MatSetUp(*B));
368:   PetscFunctionReturn(PETSC_SUCCESS);
369: }