Actual source code: bqnk.c
1: #include <../src/tao/bound/impls/bqnk/bqnk.h>
2: #include <petscksp.h>
4: static PetscErrorCode TaoBQNKComputeHessian(Tao tao)
5: {
6: TAO_BNK *bnk = (TAO_BNK *)tao->data;
7: TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx;
8: PetscReal gnorm2, delta;
10: /* Alias the LMVM matrix into the TAO hessian */
11: if (tao->hessian) {
12: MatDestroy(&tao->hessian);
13: }
14: if (tao->hessian_pre) {
15: MatDestroy(&tao->hessian_pre);
16: }
17: PetscObjectReference((PetscObject)bqnk->B);
18: tao->hessian = bqnk->B;
19: PetscObjectReference((PetscObject)bqnk->B);
20: tao->hessian_pre = bqnk->B;
21: /* Update the Hessian with the latest solution */
22: if (bqnk->is_spd) {
23: gnorm2 = bnk->gnorm*bnk->gnorm;
24: if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON;
25: if (bnk->f == 0.0) {
26: delta = 2.0 / gnorm2;
27: } else {
28: delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2;
29: }
30: MatLMVMSymBroydenSetDelta(bqnk->B, delta);
31: }
32: MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient);
33: MatLMVMResetShift(tao->hessian);
34: /* Prepare the reduced sub-matrices for the inactive set */
35: MatDestroy(&bnk->H_inactive);
36: if (bnk->active_idx) {
37: MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive);
38: PCLMVMSetIS(bqnk->pc, bnk->inactive_idx);
39: } else {
40: PetscObjectReference((PetscObject)tao->hessian);
41: bnk->H_inactive = tao->hessian;
42: PCLMVMClearIS(bqnk->pc);
43: }
44: MatDestroy(&bnk->Hpre_inactive);
45: PetscObjectReference((PetscObject)bnk->H_inactive);
46: bnk->Hpre_inactive = bnk->H_inactive;
47: return 0;
48: }
50: static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type)
51: {
52: TAO_BNK *bnk = (TAO_BNK *)tao->data;
53: TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx;
55: TaoBNKComputeStep(tao, shift, ksp_reason, step_type);
56: if (*ksp_reason < 0) {
57: /* Krylov solver failed to converge so reset the LMVM matrix */
58: MatLMVMReset(bqnk->B, PETSC_FALSE);
59: MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient);
60: }
61: return 0;
62: }
64: PetscErrorCode TaoSolve_BQNK(Tao tao)
65: {
66: TAO_BNK *bnk = (TAO_BNK *)tao->data;
67: TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx;
68: Mat_LMVM *lmvm = (Mat_LMVM*)bqnk->B->data;
69: Mat_LMVM *J0;
70: Mat_SymBrdn *diag_ctx;
71: PetscBool flg = PETSC_FALSE;
73: if (!tao->recycle) {
74: MatLMVMReset(bqnk->B, PETSC_FALSE);
75: lmvm->nresets = 0;
76: if (lmvm->J0) {
77: PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg);
78: if (flg) {
79: J0 = (Mat_LMVM*)lmvm->J0->data;
80: J0->nresets = 0;
81: }
82: }
83: flg = PETSC_FALSE;
84: PetscObjectTypeCompareAny((PetscObject)bqnk->B, &flg, MATLMVMSYMBROYDEN, MATLMVMSYMBADBROYDEN, MATLMVMBFGS, MATLMVMDFP, "");
85: if (flg) {
86: diag_ctx = (Mat_SymBrdn*)lmvm->ctx;
87: J0 = (Mat_LMVM*)diag_ctx->D->data;
88: J0->nresets = 0;
89: }
90: }
91: (*bqnk->solve)(tao);
92: return 0;
93: }
95: PetscErrorCode TaoSetUp_BQNK(Tao tao)
96: {
97: TAO_BNK *bnk = (TAO_BNK *)tao->data;
98: TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx;
99: PetscInt n, N;
100: PetscBool is_lmvm, is_sym, is_spd;
102: TaoSetUp_BNK(tao);
103: VecGetLocalSize(tao->solution,&n);
104: VecGetSize(tao->solution,&N);
105: MatSetSizes(bqnk->B, n, n, N, N);
106: MatLMVMAllocate(bqnk->B,tao->solution,bnk->unprojected_gradient);
107: PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm);
109: MatGetOption(bqnk->B, MAT_SYMMETRIC, &is_sym);
111: MatGetOption(bqnk->B, MAT_SPD, &is_spd);
112: KSPGetPC(tao->ksp, &bqnk->pc);
113: PCSetType(bqnk->pc, PCLMVM);
114: PCLMVMSetMatLMVM(bqnk->pc, bqnk->B);
115: return 0;
116: }
118: static PetscErrorCode TaoSetFromOptions_BQNK(PetscOptionItems *PetscOptionsObject,Tao tao)
119: {
120: TAO_BNK *bnk = (TAO_BNK *)tao->data;
121: TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx;
123: TaoSetFromOptions_BNK(PetscOptionsObject,tao);
124: if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION;
125: MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix);
126: MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_");
127: MatSetFromOptions(bqnk->B);
128: MatGetOption(bqnk->B, MAT_SPD, &bqnk->is_spd);
129: return 0;
130: }
132: static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer)
133: {
134: TAO_BNK *bnk = (TAO_BNK*)tao->data;
135: TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx;
136: PetscBool isascii;
138: TaoView_BNK(tao, viewer);
139: PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii);
140: if (isascii) {
141: PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO);
142: MatView(bqnk->B, viewer);
143: PetscViewerPopFormat(viewer);
144: }
145: return 0;
146: }
148: static PetscErrorCode TaoDestroy_BQNK(Tao tao)
149: {
150: TAO_BNK *bnk = (TAO_BNK*)tao->data;
151: TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx;
153: MatDestroy(&bnk->Hpre_inactive);
154: MatDestroy(&bnk->H_inactive);
155: MatDestroy(&bqnk->B);
156: PetscFree(bnk->ctx);
157: TaoDestroy_BNK(tao);
158: return 0;
159: }
161: PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao)
162: {
163: TAO_BNK *bnk;
164: TAO_BQNK *bqnk;
166: TaoCreate_BNK(tao);
167: tao->ops->solve = TaoSolve_BQNK;
168: tao->ops->setfromoptions = TaoSetFromOptions_BQNK;
169: tao->ops->destroy = TaoDestroy_BQNK;
170: tao->ops->view = TaoView_BQNK;
171: tao->ops->setup = TaoSetUp_BQNK;
173: bnk = (TAO_BNK *)tao->data;
174: bnk->computehessian = TaoBQNKComputeHessian;
175: bnk->computestep = TaoBQNKComputeStep;
176: bnk->init_type = BNK_INIT_DIRECTION;
178: PetscNewLog(tao,&bqnk);
179: bnk->ctx = (void*)bqnk;
180: bqnk->is_spd = PETSC_TRUE;
182: MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B);
183: PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1);
184: MatSetType(bqnk->B, MATLMVMSR1);
185: return 0;
186: }
188: /*@
189: TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid
190: only for quasi-Newton family of methods.
192: Input Parameters:
193: . tao - Tao solver context
195: Output Parameters:
196: . B - LMVM matrix
198: Level: advanced
200: .seealso: TAOBQNLS, TAOBQNKLS, TAOBQNKTL, TAOBQNKTR, MATLMVM, TaoSetLMVMMatrix()
201: @*/
202: PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B)
203: {
204: TAO_BNK *bnk = (TAO_BNK*)tao->data;
205: TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx;
206: PetscBool flg = PETSC_FALSE;
208: PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "");
210: *B = bqnk->B;
211: return 0;
212: }
214: /*@
215: TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid
216: only for quasi-Newton family of methods.
218: QN family of methods create their own LMVM matrices and users who wish to
219: manipulate this matrix should use TaoGetLMVMMatrix() instead.
221: Input Parameters:
222: + tao - Tao solver context
223: - B - LMVM matrix
225: Level: advanced
227: .seealso: TAOBQNLS, TAOBQNKLS, TAOBQNKTL, TAOBQNKTR, MATLMVM, TaoGetLMVMMatrix()
228: @*/
229: PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B)
230: {
231: TAO_BNK *bnk = (TAO_BNK*)tao->data;
232: TAO_BQNK *bqnk = (TAO_BQNK*)bnk->ctx;
233: PetscBool flg = PETSC_FALSE;
235: PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "");
237: PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg);
239: if (bqnk->B) {
240: MatDestroy(&bqnk->B);
241: }
242: PetscObjectReference((PetscObject)B);
243: bqnk->B = B;
244: return 0;
245: }