Actual source code: bqnk.c

  1: #include <../src/tao/bound/impls/bqnk/bqnk.h>
  2: #include <petscksp.h>

  4: static PetscErrorCode TaoBQNKComputeHessian(Tao tao)
  5: {
  6:   TAO_BNK        *bnk = (TAO_BNK *)tao->data;
  7:   TAO_BQNK       *bqnk = (TAO_BQNK*)bnk->ctx;
  8:   PetscReal      gnorm2, delta;

 10:   /* Alias the LMVM matrix into the TAO hessian */
 11:   if (tao->hessian) {
 12:     MatDestroy(&tao->hessian);
 13:   }
 14:   if (tao->hessian_pre) {
 15:     MatDestroy(&tao->hessian_pre);
 16:   }
 17:   PetscObjectReference((PetscObject)bqnk->B);
 18:   tao->hessian = bqnk->B;
 19:   PetscObjectReference((PetscObject)bqnk->B);
 20:   tao->hessian_pre = bqnk->B;
 21:   /* Update the Hessian with the latest solution */
 22:   if (bqnk->is_spd) {
 23:     gnorm2 = bnk->gnorm*bnk->gnorm;
 24:     if (gnorm2 == 0.0) gnorm2 = PETSC_MACHINE_EPSILON;
 25:     if (bnk->f == 0.0) {
 26:       delta = 2.0 / gnorm2;
 27:     } else {
 28:       delta = 2.0 * PetscAbsScalar(bnk->f) / gnorm2;
 29:     }
 30:     MatLMVMSymBroydenSetDelta(bqnk->B, delta);
 31:   }
 32:   MatLMVMUpdate(tao->hessian, tao->solution, bnk->unprojected_gradient);
 33:   MatLMVMResetShift(tao->hessian);
 34:   /* Prepare the reduced sub-matrices for the inactive set */
 35:   MatDestroy(&bnk->H_inactive);
 36:   if (bnk->active_idx) {
 37:     MatCreateSubMatrixVirtual(tao->hessian, bnk->inactive_idx, bnk->inactive_idx, &bnk->H_inactive);
 38:     PCLMVMSetIS(bqnk->pc, bnk->inactive_idx);
 39:   } else {
 40:     PetscObjectReference((PetscObject)tao->hessian);
 41:     bnk->H_inactive = tao->hessian;
 42:     PCLMVMClearIS(bqnk->pc);
 43:   }
 44:   MatDestroy(&bnk->Hpre_inactive);
 45:   PetscObjectReference((PetscObject)bnk->H_inactive);
 46:   bnk->Hpre_inactive = bnk->H_inactive;
 47:   return 0;
 48: }

 50: static PetscErrorCode TaoBQNKComputeStep(Tao tao, PetscBool shift, KSPConvergedReason *ksp_reason, PetscInt *step_type)
 51: {
 52:   TAO_BNK        *bnk = (TAO_BNK *)tao->data;
 53:   TAO_BQNK       *bqnk = (TAO_BQNK*)bnk->ctx;

 55:   TaoBNKComputeStep(tao, shift, ksp_reason, step_type);
 56:   if (*ksp_reason < 0) {
 57:     /* Krylov solver failed to converge so reset the LMVM matrix */
 58:     MatLMVMReset(bqnk->B, PETSC_FALSE);
 59:     MatLMVMUpdate(bqnk->B, tao->solution, bnk->unprojected_gradient);
 60:   }
 61:   return 0;
 62: }

 64: PetscErrorCode TaoSolve_BQNK(Tao tao)
 65: {
 66:   TAO_BNK        *bnk = (TAO_BNK *)tao->data;
 67:   TAO_BQNK       *bqnk = (TAO_BQNK*)bnk->ctx;
 68:   Mat_LMVM       *lmvm = (Mat_LMVM*)bqnk->B->data;
 69:   Mat_LMVM       *J0;
 70:   Mat_SymBrdn    *diag_ctx;
 71:   PetscBool      flg = PETSC_FALSE;

 73:   if (!tao->recycle) {
 74:     MatLMVMReset(bqnk->B, PETSC_FALSE);
 75:     lmvm->nresets = 0;
 76:     if (lmvm->J0) {
 77:       PetscObjectBaseTypeCompare((PetscObject)lmvm->J0, MATLMVM, &flg);
 78:       if (flg) {
 79:         J0 = (Mat_LMVM*)lmvm->J0->data;
 80:         J0->nresets = 0;
 81:       }
 82:     }
 83:     flg = PETSC_FALSE;
 84:     PetscObjectTypeCompareAny((PetscObject)bqnk->B, &flg, MATLMVMSYMBROYDEN, MATLMVMSYMBADBROYDEN, MATLMVMBFGS, MATLMVMDFP, "");
 85:     if (flg) {
 86:       diag_ctx = (Mat_SymBrdn*)lmvm->ctx;
 87:       J0 = (Mat_LMVM*)diag_ctx->D->data;
 88:       J0->nresets = 0;
 89:     }
 90:   }
 91:   (*bqnk->solve)(tao);
 92:   return 0;
 93: }

 95: PetscErrorCode TaoSetUp_BQNK(Tao tao)
 96: {
 97:   TAO_BNK        *bnk = (TAO_BNK *)tao->data;
 98:   TAO_BQNK       *bqnk = (TAO_BQNK*)bnk->ctx;
 99:   PetscInt       n, N;
100:   PetscBool      is_lmvm, is_sym, is_spd;

102:   TaoSetUp_BNK(tao);
103:   VecGetLocalSize(tao->solution,&n);
104:   VecGetSize(tao->solution,&N);
105:   MatSetSizes(bqnk->B, n, n, N, N);
106:   MatLMVMAllocate(bqnk->B,tao->solution,bnk->unprojected_gradient);
107:   PetscObjectBaseTypeCompare((PetscObject)bqnk->B, MATLMVM, &is_lmvm);
109:   MatGetOption(bqnk->B, MAT_SYMMETRIC, &is_sym);
111:   MatGetOption(bqnk->B, MAT_SPD, &is_spd);
112:   KSPGetPC(tao->ksp, &bqnk->pc);
113:   PCSetType(bqnk->pc, PCLMVM);
114:   PCLMVMSetMatLMVM(bqnk->pc, bqnk->B);
115:   return 0;
116: }

118: static PetscErrorCode TaoSetFromOptions_BQNK(PetscOptionItems *PetscOptionsObject,Tao tao)
119: {
120:   TAO_BNK        *bnk = (TAO_BNK *)tao->data;
121:   TAO_BQNK       *bqnk = (TAO_BQNK*)bnk->ctx;

123:   TaoSetFromOptions_BNK(PetscOptionsObject,tao);
124:   if (bnk->init_type == BNK_INIT_INTERPOLATION) bnk->init_type = BNK_INIT_DIRECTION;
125:   MatSetOptionsPrefix(bqnk->B, ((PetscObject)tao)->prefix);
126:   MatAppendOptionsPrefix(bqnk->B, "tao_bqnk_");
127:   MatSetFromOptions(bqnk->B);
128:   MatGetOption(bqnk->B, MAT_SPD, &bqnk->is_spd);
129:   return 0;
130: }

132: static PetscErrorCode TaoView_BQNK(Tao tao, PetscViewer viewer)
133: {
134:   TAO_BNK        *bnk = (TAO_BNK*)tao->data;
135:   TAO_BQNK       *bqnk = (TAO_BQNK*)bnk->ctx;
136:   PetscBool      isascii;

138:   TaoView_BNK(tao, viewer);
139:   PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii);
140:   if (isascii) {
141:     PetscViewerPushFormat(viewer, PETSC_VIEWER_ASCII_INFO);
142:     MatView(bqnk->B, viewer);
143:     PetscViewerPopFormat(viewer);
144:   }
145:   return 0;
146: }

148: static PetscErrorCode TaoDestroy_BQNK(Tao tao)
149: {
150:   TAO_BNK        *bnk = (TAO_BNK*)tao->data;
151:   TAO_BQNK       *bqnk = (TAO_BQNK*)bnk->ctx;

153:   MatDestroy(&bnk->Hpre_inactive);
154:   MatDestroy(&bnk->H_inactive);
155:   MatDestroy(&bqnk->B);
156:   PetscFree(bnk->ctx);
157:   TaoDestroy_BNK(tao);
158:   return 0;
159: }

161: PETSC_INTERN PetscErrorCode TaoCreate_BQNK(Tao tao)
162: {
163:   TAO_BNK        *bnk;
164:   TAO_BQNK       *bqnk;

166:   TaoCreate_BNK(tao);
167:   tao->ops->solve = TaoSolve_BQNK;
168:   tao->ops->setfromoptions = TaoSetFromOptions_BQNK;
169:   tao->ops->destroy = TaoDestroy_BQNK;
170:   tao->ops->view = TaoView_BQNK;
171:   tao->ops->setup = TaoSetUp_BQNK;

173:   bnk = (TAO_BNK *)tao->data;
174:   bnk->computehessian = TaoBQNKComputeHessian;
175:   bnk->computestep = TaoBQNKComputeStep;
176:   bnk->init_type = BNK_INIT_DIRECTION;

178:   PetscNewLog(tao,&bqnk);
179:   bnk->ctx = (void*)bqnk;
180:   bqnk->is_spd = PETSC_TRUE;

182:   MatCreate(PetscObjectComm((PetscObject)tao), &bqnk->B);
183:   PetscObjectIncrementTabLevel((PetscObject)bqnk->B, (PetscObject)tao, 1);
184:   MatSetType(bqnk->B, MATLMVMSR1);
185:   return 0;
186: }

188: /*@
189:    TaoGetLMVMMatrix - Returns a pointer to the internal LMVM matrix. Valid
190:    only for quasi-Newton family of methods.

192:    Input Parameters:
193: .  tao - Tao solver context

195:    Output Parameters:
196: .  B - LMVM matrix

198:    Level: advanced

200: .seealso: TAOBQNLS, TAOBQNKLS, TAOBQNKTL, TAOBQNKTR, MATLMVM, TaoSetLMVMMatrix()
201: @*/
202: PetscErrorCode TaoGetLMVMMatrix(Tao tao, Mat *B)
203: {
204:   TAO_BNK        *bnk = (TAO_BNK*)tao->data;
205:   TAO_BQNK       *bqnk = (TAO_BQNK*)bnk->ctx;
206:   PetscBool      flg = PETSC_FALSE;

208:   PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "");
210:   *B = bqnk->B;
211:   return 0;
212: }

214: /*@
215:    TaoSetLMVMMatrix - Sets an external LMVM matrix into the Tao solver. Valid
216:    only for quasi-Newton family of methods.

218:    QN family of methods create their own LMVM matrices and users who wish to
219:    manipulate this matrix should use TaoGetLMVMMatrix() instead.

221:    Input Parameters:
222: +  tao - Tao solver context
223: -  B - LMVM matrix

225:    Level: advanced

227: .seealso: TAOBQNLS, TAOBQNKLS, TAOBQNKTL, TAOBQNKTR, MATLMVM, TaoGetLMVMMatrix()
228: @*/
229: PetscErrorCode TaoSetLMVMMatrix(Tao tao, Mat B)
230: {
231:   TAO_BNK        *bnk = (TAO_BNK*)tao->data;
232:   TAO_BQNK       *bqnk = (TAO_BQNK*)bnk->ctx;
233:   PetscBool      flg = PETSC_FALSE;

235:   PetscObjectTypeCompareAny((PetscObject)tao, &flg, TAOBQNLS, TAOBQNKLS, TAOBQNKTR, TAOBQNKTL, "");
237:   PetscObjectBaseTypeCompare((PetscObject)B, MATLMVM, &flg);
239:   if (bqnk->B) {
240:     MatDestroy(&bqnk->B);
241:   }
242:   PetscObjectReference((PetscObject)B);
243:   bqnk->B = B;
244:   return 0;
245: }