Actual source code: brgn.c
petsc-3.14.6 2021-03-30
#include <../src/tao/leastsquares/impls/brgn/brgn.h>

#define BRGN_REGULARIZATION_USER    0
#define BRGN_REGULARIZATION_L2PROX  1
#define BRGN_REGULARIZATION_L2PURE  2
#define BRGN_REGULARIZATION_L1DICT  3
#define BRGN_REGULARIZATION_LM      4
#define BRGN_REGULARIZATION_TYPES   5

static const char *BRGN_REGULARIZATION_TABLE[64] = {"user","l2prox","l2pure","l1dict","lm"};
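
/* Matrix-free product with the Gauss-Newton Hessian approximation:
   out = J'*J*in plus the regularizer contribution selected by gn->reg_type
   (the user-supplied Hreg, lambda*I for the two L2 variants, the smoothed-L1
   diagonal term D'*diag*D, or the LM damping vector, which already carries
   the factor lambda). */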
static PetscErrorCode GNHessianProd(Mat H,Vec in,Vec out)
{
  TAO_BRGN       *gn;
  PetscErrorCode ierr;

  MatShellGetContext(H,&gn);
  MatMult(gn->subsolver->ls_jac,in,gn->r_work);
  MatMultTranspose(gn->subsolver->ls_jac,gn->r_work,out);
  switch (gn->reg_type) {
  case BRGN_REGULARIZATION_USER:
    MatMult(gn->Hreg,in,gn->x_work);
    VecAXPY(out,gn->lambda,gn->x_work);
    break;
  case BRGN_REGULARIZATION_L2PURE:
    VecAXPY(out,gn->lambda,in);
    break;
  case BRGN_REGULARIZATION_L2PROX:
    VecAXPY(out,gn->lambda,in);
    break;
  case BRGN_REGULARIZATION_L1DICT:
    /* out = out + lambda*D'*(diag.*(D*in)) */
    if (gn->D) {
      MatMult(gn->D,in,gn->y);  /* y = D*in */
    } else {
      VecCopy(in,gn->y);
    }
    VecPointwiseMult(gn->y_work,gn->diag,gn->y); /* y_work = diag.*(D*in), where diag = epsilon^2 ./ sqrt(x.^2+epsilon^2).^3 */
    if (gn->D) {
      MatMultTranspose(gn->D,gn->y_work,gn->x_work); /* x_work = D'*(diag.*(D*in)) */
    } else {
      VecCopy(gn->y_work,gn->x_work);
    }
    VecAXPY(out,gn->lambda,gn->x_work);
    break;
  case BRGN_REGULARIZATION_LM:
    VecPointwiseMult(gn->x_work,gn->damping,in);
    VecAXPY(out,1,gn->x_work);
    break;
  }
  return(0);
}
static PetscErrorCode ComputeDamping(TAO_BRGN *gn)
{
  const PetscScalar *diag_ary;
  PetscScalar       *damping_ary;
  PetscInt          i,n;
  PetscErrorCode    ierr;

  /* update damping */
  VecGetArray(gn->damping,&damping_ary);
  VecGetArrayRead(gn->diag,&diag_ary);
  VecGetLocalSize(gn->damping,&n);
  for (i=0; i<n; i++) {
    damping_ary[i] = PetscClipInterval(diag_ary[i],PETSC_SQRT_MACHINE_EPSILON,PetscSqrtReal(PETSC_MAX_REAL));
  }
  VecScale(gn->damping,gn->lambda);
  VecRestoreArray(gn->damping,&damping_ary);
  VecRestoreArrayRead(gn->diag,&diag_ary);
  return(0);
}
PetscErrorCode TaoBRGNGetDampingVector(Tao tao,Vec *d)
{
  TAO_BRGN *gn = (TAO_BRGN *)tao->data;

  if (gn->reg_type != BRGN_REGULARIZATION_LM) SETERRQ(PetscObjectComm((PetscObject)tao),PETSC_ERR_SUP,"Damping vector is only available if regularization type is lm.");
  *d = gn->damping;
  return(0);
}
static PetscErrorCode GNObjectiveGradientEval(Tao tao,Vec X,PetscReal *fcn,Vec G,void *ptr)
{
  TAO_BRGN       *gn = (TAO_BRGN *)ptr;
  PetscInt       K;  /* dimension of D*X */
  PetscScalar    yESum;
  PetscErrorCode ierr;
  PetscReal      f_reg;

  /* compute objective *fcn */
  /* compute first term 0.5*||ls_res||_2^2 */
  TaoComputeResidual(tao,X,tao->ls_res);
  VecDot(tao->ls_res,tao->ls_res,fcn);
  *fcn *= 0.5;
  /* compute gradient G */
  TaoComputeResidualJacobian(tao,X,tao->ls_jac,tao->ls_jac_pre);
  MatMultTranspose(tao->ls_jac,tao->ls_res,G);
  /* add the regularization contribution */
  switch (gn->reg_type) {
  case BRGN_REGULARIZATION_USER:
    (*gn->regularizerobjandgrad)(tao,X,&f_reg,gn->x_work,gn->reg_obj_ctx);
    *fcn += gn->lambda*f_reg;
    VecAXPY(G,gn->lambda,gn->x_work);
    break;
  case BRGN_REGULARIZATION_L2PURE:
    /* compute f = f + lambda*0.5*xk'*xk */
    VecDot(X,X,&f_reg);
    *fcn += gn->lambda*0.5*f_reg;
    /* compute G = G + lambda*xk */
    VecAXPY(G,gn->lambda,X);
    break;
  case BRGN_REGULARIZATION_L2PROX:
    /* compute f = f + lambda*0.5*(xk - xkm1)'*(xk - xkm1) */
    VecAXPBYPCZ(gn->x_work,1.0,-1.0,0.0,X,gn->x_old);
    VecDot(gn->x_work,gn->x_work,&f_reg);
    *fcn += gn->lambda*0.5*f_reg;
    /* compute G = G + lambda*(xk - xkm1) */
    VecAXPBYPCZ(G,gn->lambda,-gn->lambda,1.0,X,gn->x_old);
    break;
  case BRGN_REGULARIZATION_L1DICT:
    /* compute f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
    if (gn->D) {
      MatMult(gn->D,X,gn->y);  /* y = D*x */
    } else {
      VecCopy(X,gn->y);
    }
    VecPointwiseMult(gn->y_work,gn->y,gn->y);
    VecShift(gn->y_work,gn->epsilon*gn->epsilon);
    VecSqrtAbs(gn->y_work);  /* gn->y_work = sqrt(y.^2+epsilon^2) */
    VecSum(gn->y_work,&yESum);
    VecGetSize(gn->y,&K);
    *fcn += gn->lambda*(yESum - K*gn->epsilon);
    /* compute G = G + lambda*D'*(y./sqrt(y.^2+epsilon^2)), where y = D*x */
    VecPointwiseDivide(gn->y_work,gn->y,gn->y_work); /* reuse y_work = y./sqrt(y.^2+epsilon^2) */
    if (gn->D) {
      MatMultTranspose(gn->D,gn->y_work,gn->x_work);
    } else {
      VecCopy(gn->y_work,gn->x_work);
    }
    VecAXPY(G,gn->lambda,gn->x_work);
    break;
  }
  return(0);
}
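
/* Note on the l1dict gradient above: with the smoothed absolute value
   phi(y) = sqrt(y^2+epsilon^2) - epsilon, one has phi'(y) = y/sqrt(y^2+epsilon^2),
   so the gradient of the regularizer g(x) = sum(phi(D*x)) is D'*(y./sqrt(y.^2+epsilon^2))
   with y = D*x, which is exactly what is accumulated into G with weight lambda. */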
static PetscErrorCode GNComputeHessian(Tao tao,Vec X,Mat H,Mat Hpre,void *ptr)
{
  TAO_BRGN    *gn = (TAO_BRGN *)ptr;
  PetscInt    i,n,cstart,cend;
  PetscScalar *cnorms,*diag_ary;

  TaoComputeResidualJacobian(tao,X,tao->ls_jac,tao->ls_jac_pre);
  switch (gn->reg_type) {
  case BRGN_REGULARIZATION_USER:
    (*gn->regularizerhessian)(tao,X,gn->Hreg,gn->reg_hess_ctx);
    break;
  case BRGN_REGULARIZATION_L2PURE:
    break;
  case BRGN_REGULARIZATION_L2PROX:
    break;
  case BRGN_REGULARIZATION_L1DICT:
    /* calculate and store the diagonal matrix as a vector: diag = epsilon^2 ./ sqrt(y.^2+epsilon^2).^3, where y = D*x */
    if (gn->D) {
      MatMult(gn->D,X,gn->y);  /* y = D*x */
    } else {
      VecCopy(X,gn->y);
    }
    VecPointwiseMult(gn->y_work,gn->y,gn->y);
    VecShift(gn->y_work,gn->epsilon*gn->epsilon);
    VecCopy(gn->y_work,gn->diag);                   /* gn->diag = y.^2+epsilon^2 */
    VecSqrtAbs(gn->y_work);                         /* gn->y_work = sqrt(y.^2+epsilon^2) */
    VecPointwiseMult(gn->diag,gn->y_work,gn->diag); /* gn->diag = sqrt(y.^2+epsilon^2).^3 */
    VecReciprocal(gn->diag);
    VecScale(gn->diag,gn->epsilon*gn->epsilon);
    break;
  case BRGN_REGULARIZATION_LM:
    /* compute diagonal of J^T J */
    MatGetSize(gn->parent->ls_jac,NULL,&n);
    PetscMalloc1(n,&cnorms);
    MatGetColumnNorms(gn->parent->ls_jac,NORM_2,cnorms);
    MatGetOwnershipRangeColumn(gn->parent->ls_jac,&cstart,&cend);
    VecGetArray(gn->diag,&diag_ary);
    for (i = 0; i < cend-cstart; i++) {
      diag_ary[i] = cnorms[cstart+i] * cnorms[cstart+i];
    }
    VecRestoreArray(gn->diag,&diag_ary);
    PetscFree(cnorms);
    ComputeDamping(gn);
    break;
  }
  return(0);
}
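
/* Note on the Hessian terms above: for the smoothed L1 regularizer,
   phi(y) = sqrt(y^2+epsilon^2) - epsilon gives phi''(y) = epsilon^2/(y^2+epsilon^2)^(3/2),
   which is the diagonal stored in gn->diag and applied as D'*diag(phi'')*D in GNHessianProd.
   For the lm type, the squared column norms of J are exactly the diagonal entries of J'*J;
   ComputeDamping() then clips and scales them by lambda to form the Levenberg-Marquardt
   damping vector. */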
static PetscErrorCode GNHookFunction(Tao tao,PetscInt iter,void *ctx)
{
  TAO_BRGN       *gn = (TAO_BRGN *)ctx;
  PetscErrorCode ierr;

  /* Update basic tao information from the subsolver */
  gn->parent->nfuncs      = tao->nfuncs;
  gn->parent->ngrads      = tao->ngrads;
  gn->parent->nfuncgrads  = tao->nfuncgrads;
  gn->parent->nhess       = tao->nhess;
  gn->parent->niter       = tao->niter;
  gn->parent->ksp_its     = tao->ksp_its;
  gn->parent->ksp_tot_its = tao->ksp_tot_its;
  gn->parent->fc          = tao->fc;
  TaoGetConvergedReason(tao,&gn->parent->reason);
  /* Update the solution vectors */
  if (iter == 0) {
    VecSet(gn->x_old,0.0);
  } else {
    VecCopy(tao->solution,gn->x_old);
    VecCopy(tao->solution,gn->parent->solution);
  }
  /* Update the gradient */
  VecCopy(tao->gradient,gn->parent->gradient);

  /* Update damping parameter for LM */
  if (gn->reg_type == BRGN_REGULARIZATION_LM) {
    if (iter > 0) {
      if (gn->fc_old > tao->fc) {
        gn->lambda = gn->lambda * gn->downhill_lambda_change;
      } else {
        /* uphill step */
        gn->lambda = gn->lambda * gn->uphill_lambda_change;
      }
    }
    gn->fc_old = tao->fc;
  }

  /* Call general purpose update function */
  if (gn->parent->ops->update) {
    (*gn->parent->ops->update)(gn->parent,gn->parent->niter,gn->parent->user_update);
  }
  return(0);
}
static PetscErrorCode TaoSolve_BRGN(Tao tao)
{
  TAO_BRGN       *gn = (TAO_BRGN *)tao->data;
  PetscErrorCode ierr;

  TaoSolve(gn->subsolver);
  /* Update basic tao information from the subsolver */
  tao->nfuncs      = gn->subsolver->nfuncs;
  tao->ngrads      = gn->subsolver->ngrads;
  tao->nfuncgrads  = gn->subsolver->nfuncgrads;
  tao->nhess       = gn->subsolver->nhess;
  tao->niter       = gn->subsolver->niter;
  tao->ksp_its     = gn->subsolver->ksp_its;
  tao->ksp_tot_its = gn->subsolver->ksp_tot_its;
  TaoGetConvergedReason(gn->subsolver,&tao->reason);
  /* Update vectors */
  VecCopy(gn->subsolver->solution,tao->solution);
  VecCopy(gn->subsolver->gradient,tao->gradient);
  return(0);
}
static PetscErrorCode TaoSetFromOptions_BRGN(PetscOptionItems *PetscOptionsObject,Tao tao)
{
  TAO_BRGN       *gn = (TAO_BRGN *)tao->data;
  TaoLineSearch  ls;
  PetscErrorCode ierr;

  PetscOptionsHead(PetscOptionsObject,"least-squares problems with regularizer: ||f(x)||^2 + lambda*g(x), g(x) = ||xk-xkm1||^2 or ||Dx||_1 or user defined function.");
  PetscOptionsReal("-tao_brgn_regularizer_weight","regularizer weight (default 1e-4)","",gn->lambda,&gn->lambda,NULL);
  PetscOptionsReal("-tao_brgn_l1_smooth_epsilon","L1-norm smooth approximation parameter: ||x||_1 = sum(sqrt(x.^2+epsilon^2)-epsilon) (default 1e-6)","",gn->epsilon,&gn->epsilon,NULL);
  PetscOptionsReal("-tao_brgn_lm_downhill_lambda_change","Factor to decrease trust region by on downhill steps","",gn->downhill_lambda_change,&gn->downhill_lambda_change,NULL);
  PetscOptionsReal("-tao_brgn_lm_uphill_lambda_change","Factor to increase trust region by on uphill steps","",gn->uphill_lambda_change,&gn->uphill_lambda_change,NULL);
  PetscOptionsEList("-tao_brgn_regularization_type","regularization type","",BRGN_REGULARIZATION_TABLE,BRGN_REGULARIZATION_TYPES,BRGN_REGULARIZATION_TABLE[gn->reg_type],&gn->reg_type,NULL);
  PetscOptionsTail();
  /* set unit line search direction as the default when using the lm regularizer */
  if (gn->reg_type == BRGN_REGULARIZATION_LM) {
    TaoGetLineSearch(gn->subsolver,&ls);
    TaoLineSearchSetType(ls,TAOLINESEARCHUNIT);
  }
  TaoSetFromOptions(gn->subsolver);
  return(0);
}
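
/* Usage sketch (illustrative only): typical command-line usage of the options registered above,
   e.g. selecting the smoothed-L1 dictionary regularizer. The executable name ./app and the
   chosen values are hypothetical; the option names come from this function and from the
   "tao_brgn_subsolver_" prefix set in TaoCreate_BRGN().

     ./app -tao_type brgn -tao_brgn_regularization_type l1dict \
           -tao_brgn_regularizer_weight 1e-3 -tao_brgn_l1_smooth_epsilon 1e-6 \
           -tao_brgn_subsolver_tao_monitor
*/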
static PetscErrorCode TaoView_BRGN(Tao tao,PetscViewer viewer)
{
  TAO_BRGN       *gn = (TAO_BRGN *)tao->data;
  PetscErrorCode ierr;

  PetscViewerASCIIPushTab(viewer);
  TaoView(gn->subsolver,viewer);
  PetscViewerASCIIPopTab(viewer);
  return(0);
}
static PetscErrorCode TaoSetUp_BRGN(Tao tao)
{
  TAO_BRGN       *gn = (TAO_BRGN *)tao->data;
  PetscErrorCode ierr;
  PetscBool      is_bnls,is_bntr,is_bntl;
  PetscInt       i,n,N,K; /* dict has size K*N */

  if (!tao->ls_res) SETERRQ(PetscObjectComm((PetscObject)tao),PETSC_ERR_ORDER,"TaoSetResidualRoutine() must be called before setup!");
  PetscObjectTypeCompare((PetscObject)gn->subsolver,TAOBNLS,&is_bnls);
  PetscObjectTypeCompare((PetscObject)gn->subsolver,TAOBNTR,&is_bntr);
  PetscObjectTypeCompare((PetscObject)gn->subsolver,TAOBNTL,&is_bntl);
  if ((is_bnls || is_bntr || is_bntl) && !tao->ls_jac) SETERRQ(PetscObjectComm((PetscObject)tao),PETSC_ERR_ORDER,"TaoSetJacobianResidualRoutine() must be called before setup!");
  if (!tao->gradient) {
    VecDuplicate(tao->solution,&tao->gradient);
  }
  if (!gn->x_work) {
    VecDuplicate(tao->solution,&gn->x_work);
  }
  if (!gn->r_work) {
    VecDuplicate(tao->ls_res,&gn->r_work);
  }
  if (!gn->x_old) {
    VecDuplicate(tao->solution,&gn->x_old);
    VecSet(gn->x_old,0.0);
  }

  if (BRGN_REGULARIZATION_L1DICT == gn->reg_type) {
    if (gn->D) {
      MatGetSize(gn->D,&K,&N); /* Shell matrices still must have sizes defined. K = N for identity matrix, K = N-1 or N for gradient matrix */
    } else {
      VecGetSize(tao->solution,&K); /* If the user does not set up a dictionary matrix, use the identity matrix, K = N */
    }
    if (!gn->y) {
      VecCreate(PETSC_COMM_SELF,&gn->y);
      VecSetSizes(gn->y,PETSC_DECIDE,K);
      VecSetFromOptions(gn->y);
      VecSet(gn->y,0.0);
    }
    if (!gn->y_work) {
      VecDuplicate(gn->y,&gn->y_work);
    }
    if (!gn->diag) {
      VecDuplicate(gn->y,&gn->diag);
      VecSet(gn->diag,0.0);
    }
  }
  if (BRGN_REGULARIZATION_LM == gn->reg_type) {
    if (!gn->diag) {
      MatCreateVecs(gn->parent->ls_jac,&gn->diag,NULL);
    }
    if (!gn->damping) {
      MatCreateVecs(gn->parent->ls_jac,&gn->damping,NULL);
    }
  }

  if (!tao->setupcalled) {
    /* Hessian setup */
    VecGetLocalSize(tao->solution,&n);
    VecGetSize(tao->solution,&N);
    MatSetSizes(gn->H,n,n,N,N);
    MatSetType(gn->H,MATSHELL);
    MatSetUp(gn->H);
    MatShellSetOperation(gn->H,MATOP_MULT,(void (*)(void))GNHessianProd);
    MatShellSetContext(gn->H,(void*)gn);
    /* Subsolver setup, including initial vector and dictionary D */
    TaoSetUpdate(gn->subsolver,GNHookFunction,(void*)gn);
    TaoSetInitialVector(gn->subsolver,tao->solution);
    if (tao->bounded) {
      TaoSetVariableBounds(gn->subsolver,tao->XL,tao->XU);
    }
    TaoSetResidualRoutine(gn->subsolver,tao->ls_res,tao->ops->computeresidual,tao->user_lsresP);
    TaoSetJacobianResidualRoutine(gn->subsolver,tao->ls_jac,tao->ls_jac,tao->ops->computeresidualjacobian,tao->user_lsjacP);
    TaoSetObjectiveAndGradientRoutine(gn->subsolver,GNObjectiveGradientEval,(void*)gn);
    TaoSetHessianRoutine(gn->subsolver,gn->H,gn->H,GNComputeHessian,(void*)gn);
    /* Propagate some options down */
    TaoSetTolerances(gn->subsolver,tao->gatol,tao->grtol,tao->gttol);
    TaoSetMaximumIterations(gn->subsolver,tao->max_it);
    TaoSetMaximumFunctionEvaluations(gn->subsolver,tao->max_funcs);
    for (i=0; i<tao->numbermonitors; ++i) {
      TaoSetMonitor(gn->subsolver,tao->monitor[i],tao->monitorcontext[i],tao->monitordestroy[i]);
      PetscObjectReference((PetscObject)(tao->monitorcontext[i]));
    }
    TaoSetUp(gn->subsolver);
  }
  return(0);
}
static PetscErrorCode TaoDestroy_BRGN(Tao tao)
{
  TAO_BRGN       *gn = (TAO_BRGN *)tao->data;
  PetscErrorCode ierr;

  if (tao->setupcalled) {
    VecDestroy(&tao->gradient);
    VecDestroy(&gn->x_work);
    VecDestroy(&gn->r_work);
    VecDestroy(&gn->x_old);
    VecDestroy(&gn->diag);
    VecDestroy(&gn->y);
    VecDestroy(&gn->y_work);
  }
  VecDestroy(&gn->damping);
  VecDestroy(&gn->diag);
  MatDestroy(&gn->H);
  MatDestroy(&gn->D);
  MatDestroy(&gn->Hreg);
  TaoDestroy(&gn->subsolver);
  gn->parent = NULL;
  PetscFree(tao->data);
  return(0);
}
/*MC
  TAOBRGN - Bounded Regularized Gauss-Newton method for solving nonlinear least-squares
            problems with bound constraints. This algorithm is a thin wrapper around a bounded
            Newton subsolver (TAOBNLS by default) that constructs the Gauss-Newton problem with
            the user-provided least-squares residual and Jacobian. The algorithm offers an L2-norm
            regularizer ("l2pure"), an L2-norm proximal point regularizer ("l2prox"), and an L1-norm
            dictionary regularizer ("l1dict"), where the L1-norm ||x||_1 is approximated by
            sum_i(sqrt(x_i^2+epsilon^2)-epsilon) with a small positive number epsilon.
            Also offered is the "lm" regularizer, which uses a scaled diagonal of J^T J.
            With the "lm" regularizer, BRGN is a Levenberg-Marquardt optimizer.
            The user can also provide their own regularization function.

  Options Database Keys:
+ -tao_brgn_regularization_type - regularization type ("user", "l2prox", "l2pure", "l1dict", "lm") (default "l2prox")
. -tao_brgn_regularizer_weight - regularizer weight (default 1e-4)
- -tao_brgn_l1_smooth_epsilon - L1-norm smooth approximation parameter: ||x||_1 = sum(sqrt(x.^2+epsilon^2)-epsilon) (default 1e-6)

  Level: beginner
M*/
PETSC_EXTERN PetscErrorCode TaoCreate_BRGN(Tao tao)
{
  TAO_BRGN *gn;

  PetscNewLog(tao,&gn);

  tao->ops->destroy        = TaoDestroy_BRGN;
  tao->ops->setup          = TaoSetUp_BRGN;
  tao->ops->setfromoptions = TaoSetFromOptions_BRGN;
  tao->ops->view           = TaoView_BRGN;
  tao->ops->solve          = TaoSolve_BRGN;

  tao->data                  = (void*)gn;
  gn->reg_type               = BRGN_REGULARIZATION_L2PROX;
  gn->lambda                 = 1e-4;
  gn->epsilon                = 1e-6;
  gn->downhill_lambda_change = 1./5.;
  gn->uphill_lambda_change   = 1.5;
  gn->parent                 = tao;

  MatCreate(PetscObjectComm((PetscObject)tao),&gn->H);
  MatSetOptionsPrefix(gn->H,"tao_brgn_hessian_");

  TaoCreate(PetscObjectComm((PetscObject)tao),&gn->subsolver);
  TaoSetType(gn->subsolver,TAOBNLS);
  TaoSetOptionsPrefix(gn->subsolver,"tao_brgn_subsolver_");
  return(0);
}
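
/* Usage sketch (illustrative only): a minimal driver that selects TAOBRGN. The names
   EvaluateResidual(), EvaluateJacobian(), user, x, r, and J are hypothetical: user callbacks
   with the signatures expected by TaoSetResidualRoutine() and TaoSetJacobianResidualRoutine(),
   an application context, and user-created solution Vec, residual Vec, and Jacobian Mat.

     Tao tao;
     TaoCreate(PETSC_COMM_WORLD,&tao);
     TaoSetType(tao,TAOBRGN);
     TaoSetInitialVector(tao,x);
     TaoSetResidualRoutine(tao,r,EvaluateResidual,(void*)&user);
     TaoSetJacobianResidualRoutine(tao,J,J,EvaluateJacobian,(void*)&user);
     TaoSetFromOptions(tao);
     TaoSolve(tao);
     TaoDestroy(&tao);
*/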
/*@
  TaoBRGNGetSubsolver - Get the pointer to the subsolver inside BRGN

  Collective on Tao

  Input Parameter:
. tao - the Tao solver context

  Output Parameter:
. subsolver - the Tao sub-solver context

  Level: advanced
@*/
PetscErrorCode TaoBRGNGetSubsolver(Tao tao,Tao *subsolver)
{
  TAO_BRGN *gn = (TAO_BRGN *)tao->data;

  *subsolver = gn->subsolver;
  return(0);
}
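
/* Usage sketch (illustrative only): retrieve the subsolver to customize it beyond what the
   options database exposes, e.g. tightening its gradient tolerance; the tolerance value is
   arbitrary.

     Tao subsolver;
     TaoBRGNGetSubsolver(tao,&subsolver);
     TaoSetTolerances(subsolver,1e-10,0.0,0.0);
*/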
/*@
  TaoBRGNSetRegularizerWeight - Set the regularizer weight for the Gauss-Newton least-squares algorithm

  Collective on Tao

  Input Parameters:
+ tao - the Tao solver context
- lambda - the regularizer weight

  Level: beginner
@*/
PetscErrorCode TaoBRGNSetRegularizerWeight(Tao tao,PetscReal lambda)
{
  TAO_BRGN *gn = (TAO_BRGN *)tao->data;

  /* Initialize lambda here */
  gn->lambda = lambda;
  return(0);
}
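
/* Usage sketch (illustrative only): set the weight programmatically; the same value can also
   be supplied at run time with -tao_brgn_regularizer_weight. The value 1e-3 is arbitrary.

     TaoBRGNSetRegularizerWeight(tao,1e-3);
*/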
/*@
  TaoBRGNSetL1SmoothEpsilon - Set the L1-norm smooth approximation parameter for the L1-regularized least-squares algorithm

  Collective on Tao

  Input Parameters:
+ tao - the Tao solver context
- epsilon - L1-norm smooth approximation parameter

  Level: advanced
@*/
PetscErrorCode TaoBRGNSetL1SmoothEpsilon(Tao tao,PetscReal epsilon)
{
  TAO_BRGN *gn = (TAO_BRGN *)tao->data;

  /* Initialize epsilon here */
  gn->epsilon = epsilon;
  return(0);
}
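
/* Usage sketch (illustrative only): epsilon is only used by the "l1dict" regularization type;
   this call is equivalent to the run-time option -tao_brgn_l1_smooth_epsilon. The value is
   arbitrary.

     TaoBRGNSetL1SmoothEpsilon(tao,1e-6);
*/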
/*@
  TaoBRGNSetDictionaryMatrix - Bind the user-supplied dictionary matrix to gn->D for the L1 regularizer,
  e.g. for compressed sensing formulated as a least-squares problem

  Input Parameters:
+ tao - the Tao context
- dict - the user-specified dictionary matrix; a NULL dictionary is allowed and is treated as the identity matrix

  Level: advanced
@*/
PetscErrorCode TaoBRGNSetDictionaryMatrix(Tao tao,Mat dict)
{
  TAO_BRGN *gn = (TAO_BRGN *)tao->data;

  if (dict) {
    PetscObjectReference((PetscObject)dict);
  }
  MatDestroy(&gn->D);
  gn->D = dict;
  return(0);
}
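
/* Usage sketch (illustrative only): build a forward-difference matrix D of size (N-1) x N so
   that the "l1dict" regularizer penalizes ||Dx||_1 (a total-variation-like term), then hand it
   to BRGN. N is the global solution size; the assembly loop below assumes a sequential run.

     Mat         D;
     PetscInt    row,cols[2];
     PetscScalar vals[2] = {-1.0,1.0};
     MatCreate(PETSC_COMM_WORLD,&D);
     MatSetSizes(D,PETSC_DECIDE,PETSC_DECIDE,N-1,N);
     MatSetFromOptions(D);
     MatSetUp(D);
     for (row=0; row<N-1; row++) {
       cols[0] = row; cols[1] = row+1;
       MatSetValues(D,1,&row,2,cols,vals,INSERT_VALUES);
     }
     MatAssemblyBegin(D,MAT_FINAL_ASSEMBLY);
     MatAssemblyEnd(D,MAT_FINAL_ASSEMBLY);
     TaoBRGNSetDictionaryMatrix(tao,D);
     MatDestroy(&D);   // BRGN keeps its own reference
*/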
/*@C
  TaoBRGNSetRegularizerObjectiveAndGradientRoutine - Sets the user-defined callback that evaluates
  the regularizer objective value and gradient, used when the regularization type is "user".

  Input Parameters:
+ tao - the Tao context
. func - function pointer for the regularizer value and gradient evaluation
- ctx - user context for the regularizer

  Level: advanced
@*/
PetscErrorCode TaoBRGNSetRegularizerObjectiveAndGradientRoutine(Tao tao,PetscErrorCode (*func)(Tao,Vec,PetscReal *,Vec,void*),void *ctx)
{
  TAO_BRGN *gn = (TAO_BRGN *)tao->data;

  if (ctx) {
    gn->reg_obj_ctx = ctx;
  }
  if (func) {
    gn->regularizerobjandgrad = func;
  }
  return(0);
}
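
/* Usage sketch (illustrative only): a user regularizer g(x) = 0.5*||x||^2 with gradient x,
   registered together with -tao_brgn_regularization_type user. MyRegObjGrad and my_ctx are
   hypothetical names for the callback and the user context.

     static PetscErrorCode MyRegObjGrad(Tao tao,Vec X,PetscReal *f,Vec G,void *ctx)
     {
       PetscScalar xtx;
       VecDot(X,X,&xtx);
       *f = 0.5*PetscRealPart(xtx);
       VecCopy(X,G);
       return(0);
     }

     ...
     TaoBRGNSetRegularizerObjectiveAndGradientRoutine(tao,MyRegObjGrad,(void*)&my_ctx);
*/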
/*@C
  TaoBRGNSetRegularizerHessianRoutine - Sets the user-defined callback that evaluates the Hessian
  of the regularizer, used when the regularization type is "user".

  Input Parameters:
+ tao - the Tao context
. Hreg - user-created matrix for the Hessian of the regularization term
. func - function pointer for the regularizer Hessian evaluation
- ctx - user context for the regularizer Hessian

  Level: advanced
@*/
PetscErrorCode TaoBRGNSetRegularizerHessianRoutine(Tao tao,Mat Hreg,PetscErrorCode (*func)(Tao,Vec,Mat,void*),void *ctx)
{
  TAO_BRGN *gn = (TAO_BRGN *)tao->data;

  if (!Hreg) SETERRQ(PetscObjectComm((PetscObject)tao),PETSC_ERR_ARG_WRONG,"NULL Hessian detected! User must provide a valid Hessian for the regularizer.");
  if (ctx) {
    gn->reg_hess_ctx = ctx;
  }
  if (func) {
    gn->regularizerhessian = func;
  }
  PetscObjectReference((PetscObject)Hreg);
  MatDestroy(&gn->Hreg);
  gn->Hreg = Hreg;
  return(0);
}
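
/* Usage sketch (illustrative only): the Hessian of g(x) = 0.5*||x||^2 is the identity, so the
   callback can simply reset Hreg to I on every evaluation. MyRegHessian and Hreg_user are
   hypothetical names; Hreg_user must be a user-created, assembled matrix of the right size
   with its diagonal present in the nonzero pattern.

     static PetscErrorCode MyRegHessian(Tao tao,Vec X,Mat Hreg,void *ctx)
     {
       MatZeroEntries(Hreg);
       MatShift(Hreg,1.0);
       MatAssemblyBegin(Hreg,MAT_FINAL_ASSEMBLY);
       MatAssemblyEnd(Hreg,MAT_FINAL_ASSEMBLY);
       return(0);
     }

     ...
     TaoBRGNSetRegularizerHessianRoutine(tao,Hreg_user,MyRegHessian,NULL);
*/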