Actual source code: tomographyADMM.c

  1: #include <petsctao.h>
  2: /*
  3: Description:   ADMM tomography reconstruction example .
  4:                0.5*||Ax-b||^2 + lambda*g(x)
  5: Reference:     BRGN Tomography Example
  6: */

  8: static char help[] = "Finds the ADMM solution to the under constraint linear model Ax = b, with regularizer. \n\
  9:                       A is a M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\
 10:                       We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is user specified weight. \n\
 11:                       g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n\
 12:                       natively supported in admm.c with -tao_admm_regularizer_type soft-threshold. Or user can use regular TAO solver for  \n\
 13:                       either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n\
 14:                       Then, we augment both f and g, and solve it via ADMM. \n\
 15:                       D is the M*N transform matrix so that D*x is sparse. \n";

 17: typedef struct {
 18:   PetscInt  M,N,K,reg;
 19:   PetscReal lambda,eps,mumin;
 20:   Mat       A,ATA,H,Hx,D,Hz,DTD,HF;
 21:   Vec       c,xlb,xub,x,b,workM,workN,workN2,workN3,xGT;    /* observation b, ground truth xGT, the lower bound and upper bound of x*/
 22: } AppCtx;

 24: /*------------------------------------------------------------*/

 26: PetscErrorCode NullJacobian(Tao tao,Vec X,Mat J,Mat Jpre,void *ptr)
 27: {
 29:   return(0);
 30: }

 32: /*------------------------------------------------------------*/

 34: static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao)
 35: {
 37:   PetscReal      lambda, mu;
 38:   AppCtx         *user;
 39:   Vec            out,work,y,x;
 40:   Tao            admm_tao,misfit;

 43:   user = NULL;
 44:   mu   = 0;
 45:   TaoGetADMMParentTao(tao,&admm_tao);
 46:   TaoADMMGetMisfitSubsolver(admm_tao, &misfit);
 47:   TaoADMMGetSpectralPenalty(admm_tao,&mu);
 48:   TaoShellGetContext(tao,&user);

 50:   lambda = user->lambda;
 51:   work   = user->workN;
 52:   TaoGetSolutionVector(tao, &out);
 53:   TaoGetSolutionVector(misfit, &x);
 54:   TaoADMMGetDualVector(admm_tao, &y);

 56:   /* Dx + y/mu */
 57:   MatMult(user->D,x,work);
 58:   VecAXPY(work,1/mu,y);

 60:   /* soft thresholding */
 61:   TaoSoftThreshold(work, -lambda/mu, lambda/mu, out);
 62:   return(0);
 63: }

 65: /*------------------------------------------------------------*/

 67: PetscErrorCode MisfitObjectiveAndGradient(Tao tao,Vec X,PetscReal *f,Vec g,void *ptr)
 68: {
 69:   AppCtx         *user = (AppCtx*)ptr;

 73:   /* Objective  0.5*||Ax-b||_2^2 */
 74:   MatMult(user->A,X,user->workM);
 75:   VecAXPY(user->workM,-1,user->b);
 76:   VecDot(user->workM,user->workM,f);
 77:   *f  *= 0.5;
 78:   /* Gradient. ATAx-ATb */
 79:   MatMult(user->ATA,X,user->workN);
 80:   MatMultTranspose(user->A,user->b,user->workN2);
 81:   VecWAXPY(g,-1.,user->workN2,user->workN);
 82:   return(0);
 83: }

 85: /*------------------------------------------------------------*/

 87: PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao,Vec X,PetscReal *f_reg,Vec G_reg,void *ptr)
 88: {
 89:   AppCtx         *user = (AppCtx*)ptr;

 93:   /* compute regularizer objective
 94:    * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
 95:   VecCopy(X,user->workN2);
 96:   VecPow(user->workN2,2.);
 97:   VecShift(user->workN2,user->eps*user->eps);
 98:   VecSqrtAbs(user->workN2);
 99:   VecCopy(user->workN2, user->workN3);
100:   VecShift(user->workN2,-user->eps);
101:   VecSum(user->workN2,f_reg);
102:   *f_reg *= user->lambda;
103:   /* compute regularizer gradient = lambda*x */
104:   VecPointwiseDivide(G_reg,X,user->workN3);
105:   VecScale(G_reg,user->lambda);
106:   return(0);
107: }

109: /*------------------------------------------------------------*/

111: PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao,Vec X,PetscReal *f_reg,Vec G_reg,void *ptr)
112: {
113:   AppCtx         *user = (AppCtx*)ptr;
115:   PetscReal      temp;

118:   /* compute regularizer objective = lambda*|z|_2^2 */
119:   VecDot(X,X,&temp);
120:   *f_reg = 0.5*user->lambda*temp;
121:   /* compute regularizer gradient = lambda*z */
122:   VecCopy(X,G_reg);
123:   VecScale(G_reg,user->lambda);
124:   return(0);
125: }

127: /*------------------------------------------------------------*/

129: static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
130: {
132:   return(0);
133: }

135: /*------------------------------------------------------------*/

137: static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
138: {
139:   AppCtx         *user = (AppCtx*)ptr;

143:   MatMult(user->D,x,user->workN);
144:   VecPow(user->workN2,2.);
145:   VecShift(user->workN2,user->eps*user->eps);
146:   VecSqrtAbs(user->workN2);
147:   VecShift(user->workN2,-user->eps);
148:   VecReciprocal(user->workN2);
149:   VecScale(user->workN2,user->eps*user->eps);
150:   MatDiagonalSet(H,user->workN2,INSERT_VALUES);
151:   return(0);
152: }

154: /*------------------------------------------------------------*/

156: PetscErrorCode FullObjGrad(Tao tao,Vec X,PetscReal *f,Vec g,void *ptr)
157: {
158:   AppCtx         *user = (AppCtx*)ptr;
160:   PetscReal      f_reg;

163:   /* Objective  0.5*||Ax-b||_2^2 + lambda*||x||_2^2*/
164:   MatMult(user->A,X,user->workM);
165:   VecAXPY(user->workM,-1,user->b);
166:   VecDot(user->workM,user->workM,f);
167:   VecNorm(X,NORM_2,&f_reg);
168:   *f  *= 0.5;
169:   *f  += user->lambda*f_reg*f_reg;
170:   /* Gradient. ATAx-ATb + 2*lambda*x */
171:   MatMult(user->ATA,X,user->workN);
172:   MatMultTranspose(user->A,user->b,user->workN2);
173:   VecWAXPY(g,-1.,user->workN2,user->workN);
174:   VecAXPY(g,2*user->lambda,X);
175:   return(0);
176: }
177: /*------------------------------------------------------------*/

179: static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
180: {
182:   return(0);
183: }
184: /*------------------------------------------------------------*/

186: PetscErrorCode InitializeUserData(AppCtx *user)
187: {
188:   char           dataFile[] = "tomographyData_A_b_xGT";   /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */
189:   PetscViewer    fd;   /* used to load data from file */
191:   PetscInt       k,n;
192:   PetscScalar    v;

195:   /* Load the A matrix, b vector, and xGT vector from a binary file. */
196:   PetscViewerBinaryOpen(PETSC_COMM_WORLD,dataFile,FILE_MODE_READ,&fd);
197:   MatCreate(PETSC_COMM_WORLD,&user->A);
198:   MatSetType(user->A,MATAIJ);
199:   MatLoad(user->A,fd);
200:   VecCreate(PETSC_COMM_WORLD,&user->b);
201:   VecLoad(user->b,fd);
202:   VecCreate(PETSC_COMM_WORLD,&user->xGT);
203:   VecLoad(user->xGT,fd);
204:   PetscViewerDestroy(&fd);

206:   MatGetSize(user->A,&user->M,&user->N);

208:   MatCreate(PETSC_COMM_WORLD,&user->D);
209:   MatSetSizes(user->D,PETSC_DECIDE,PETSC_DECIDE,user->N,user->N);
210:   MatSetFromOptions(user->D);
211:   MatSetUp(user->D);
212:   for (k=0; k<user->N; k++) {
213:     v = 1.0;
214:     n = k+1;
215:     if (k< user->N -1) {
216:       MatSetValues(user->D,1,&k,1,&n,&v,INSERT_VALUES);
217:     }
218:     v    = -1.0;
219:     MatSetValues(user->D,1,&k,1,&k,&v,INSERT_VALUES);
220:   }
221:   MatAssemblyBegin(user->D,MAT_FINAL_ASSEMBLY);
222:   MatAssemblyEnd(user->D,MAT_FINAL_ASSEMBLY);

224:   MatTransposeMatMult(user->D,user->D,MAT_INITIAL_MATRIX,PETSC_DEFAULT,&user->DTD);

226:   MatCreate(PETSC_COMM_WORLD,&user->Hz);
227:   MatSetSizes(user->Hz,PETSC_DECIDE,PETSC_DECIDE,user->N,user->N);
228:   MatSetFromOptions(user->Hz);
229:   MatSetUp(user->Hz);
230:   MatAssemblyBegin(user->Hz,MAT_FINAL_ASSEMBLY);
231:   MatAssemblyEnd(user->Hz,MAT_FINAL_ASSEMBLY);

233:   VecCreate(PETSC_COMM_WORLD,&(user->x));
234:   VecCreate(PETSC_COMM_WORLD,&(user->workM));
235:   VecCreate(PETSC_COMM_WORLD,&(user->workN));
236:   VecCreate(PETSC_COMM_WORLD,&(user->workN2));
237:   VecSetSizes(user->x,PETSC_DECIDE,user->N);
238:   VecSetSizes(user->workM,PETSC_DECIDE,user->M);
239:   VecSetSizes(user->workN,PETSC_DECIDE,user->N);
240:   VecSetSizes(user->workN2,PETSC_DECIDE,user->N);
241:   VecSetFromOptions(user->x);
242:   VecSetFromOptions(user->workM);
243:   VecSetFromOptions(user->workN);
244:   VecSetFromOptions(user->workN2);

246:   VecDuplicate(user->workN,&(user->workN3));
247:   VecDuplicate(user->x,&(user->xlb));
248:   VecDuplicate(user->x,&(user->xub));
249:   VecDuplicate(user->x,&(user->c));
250:   VecSet(user->xlb,0.0);
251:   VecSet(user->c,0.0);
252:   VecSet(user->xub,PETSC_INFINITY);

254:   MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->ATA));
255:   MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->Hx));
256:   MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->HF));

258:   MatAssemblyBegin(user->ATA,MAT_FINAL_ASSEMBLY);
259:   MatAssemblyEnd(user->ATA,MAT_FINAL_ASSEMBLY);
260:   MatAssemblyBegin(user->Hx,MAT_FINAL_ASSEMBLY);
261:   MatAssemblyEnd(user->Hx,MAT_FINAL_ASSEMBLY);
262:   MatAssemblyBegin(user->HF,MAT_FINAL_ASSEMBLY);
263:   MatAssemblyEnd(user->HF,MAT_FINAL_ASSEMBLY);

265:   user->lambda = 1.e-8;
266:   user->eps    = 1.e-3;
267:   user->reg    = 2;
268:   user->mumin  = 5.e-6;

270:   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c");
271:   PetscOptionsInt("-reg","Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &(user->reg), NULL);
272:   PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &(user->lambda), NULL);
273:   PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &(user->eps), NULL);
274:   PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &(user->mumin), NULL);
275:   PetscOptionsEnd();
276:   return(0);
277: }

279: /*------------------------------------------------------------*/

281: PetscErrorCode DestroyContext(AppCtx *user)
282: {

286:   MatDestroy(&user->A);
287:   MatDestroy(&user->ATA);
288:   MatDestroy(&user->Hx);
289:   MatDestroy(&user->Hz);
290:   MatDestroy(&user->HF);
291:   MatDestroy(&user->D);
292:   MatDestroy(&user->DTD);
293:   VecDestroy(&user->xGT);
294:   VecDestroy(&user->xlb);
295:   VecDestroy(&user->xub);
296:   VecDestroy(&user->b);
297:   VecDestroy(&user->x);
298:   VecDestroy(&user->c);
299:   VecDestroy(&user->workN3);
300:   VecDestroy(&user->workN2);
301:   VecDestroy(&user->workN);
302:   VecDestroy(&user->workM);
303:   return(0);
304: }

306: /*------------------------------------------------------------*/

308: int main(int argc,char **argv)
309: {
311:   Tao            tao,misfit,reg;
312:   PetscReal      v1,v2;
313:   AppCtx*        user;
314:   PetscViewer    fd;
315:   char           resultFile[] = "tomographyResult_x";

317:   PetscInitialize(&argc,&argv,(char*)0,help);if (ierr) return ierr;
318:   PetscNew(&user);
319:   InitializeUserData(user);

321:   TaoCreate(PETSC_COMM_WORLD, &tao);
322:   TaoSetType(tao, TAOADMM);
323:   TaoSetInitialVector(tao, user->x);
324:   /* f(x) + g(x) for parent tao */
325:   TaoADMMSetSpectralPenalty(tao,1.);
326:   TaoSetObjectiveAndGradientRoutine(tao, FullObjGrad, (void*)user);
327:   MatShift(user->HF,user->lambda);
328:   TaoSetHessianRoutine(tao, user->HF, user->HF, HessianFull, (void*)user);

330:   /* f(x) for misfit tao */
331:   TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void*)user);
332:   TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void*)user);
333:   TaoADMMSetMisfitHessianChangeStatus(tao,PETSC_FALSE);
334:   TaoADMMSetMisfitConstraintJacobian(tao,user->D,user->D,NullJacobian,(void*)user);

336:   /* g(x) for regularizer tao */
337:   if (user->reg == 1) {
338:     TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void*)user);
339:     TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void*)user);
340:     TaoADMMSetRegHessianChangeStatus(tao,PETSC_TRUE);
341:   } else if (user->reg == 2) {
342:     TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void*)user);
343:     MatShift(user->Hz,1);
344:     MatScale(user->Hz,user->lambda);
345:     TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void*)user);
346:     TaoADMMSetRegHessianChangeStatus(tao,PETSC_TRUE);
347:   } else if (user->reg != 3) SETERRQ(PETSC_COMM_WORLD, PETSC_ERR_ARG_UNKNOWN_TYPE, "Incorrect Reg type"); /* TaoShell case */

349:   /* Set type for the misfit solver */
350:   TaoADMMGetMisfitSubsolver(tao, &misfit);
351:   TaoADMMGetRegularizationSubsolver(tao, &reg);
352:   TaoSetType(misfit,TAONLS);
353:   if (user->reg == 3) {
354:     TaoSetType(reg,TAOSHELL);
355:     TaoShellSetContext(reg, (void*) user);
356:     TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold);
357:   } else {
358:     TaoSetType(reg,TAONLS);
359:   }
360:   TaoSetVariableBounds(misfit,user->xlb,user->xub);

362:   /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */
363:   TaoADMMSetRegularizerCoefficient(tao, user->lambda);
364:   TaoADMMSetRegularizerConstraintJacobian(tao,NULL,NULL,NullJacobian,(void*)user);
365:   TaoADMMSetMinimumSpectralPenalty(tao,user->mumin);

367:   TaoADMMSetConstraintVectorRHS(tao,user->c);
368:   TaoSetFromOptions(tao);
369:   TaoSolve(tao);

371:   /* Save x (reconstruction of object) vector to a binary file, which maybe read from Matlab and convert to a 2D image for comparison. */
372:   PetscViewerBinaryOpen(PETSC_COMM_WORLD,resultFile,FILE_MODE_WRITE,&fd);
373:   VecView(user->x,fd);
374:   PetscViewerDestroy(&fd);

376:   /* compute the error */
377:   VecAXPY(user->x,-1,user->xGT);
378:   VecNorm(user->x,NORM_2,&v1);
379:   VecNorm(user->xGT,NORM_2,&v2);
380:   PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1/v2));

382:   /* Free TAO data structures */
383:   TaoDestroy(&tao);
384:   DestroyContext(user);
385:   PetscFree(user);
386:   PetscFinalize();
387:   return ierr;
388: }

390: /*TEST

392:    build:
393:       requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)

395:    test:
396:       suffix: 1
397:       localrunfiles: tomographyData_A_b_xGT
398:       args:  -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc

400:    test:
401:       suffix: 2
402:       localrunfiles: tomographyData_A_b_xGT
403:       args:  -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8  -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor

405:    test:
406:       suffix: 3
407:       localrunfiles: tomographyData_A_b_xGT
408:       args:  -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor

410:    test:
411:       suffix: 4
412:       localrunfiles: tomographyData_A_b_xGT
413:       args:  -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc

415:    test:
416:       suffix: 5
417:       localrunfiles: tomographyData_A_b_xGT
418:       args:  -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc

420:    test:
421:       suffix: 6
422:       localrunfiles: tomographyData_A_b_xGT
423:       args:  -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc

425: TEST*/