Actual source code: tomographyADMM.c
petsc-3.14.6 2021-03-30
1: #include <petsctao.h>
2: /*
Description: ADMM tomography reconstruction example.
4: 0.5*||Ax-b||^2 + lambda*g(x)
5: Reference: BRGN Tomography Example
6: */
/* Usage text printed by -help; describes the model, the splitting and the -reg choices. */
static char help[] = "Finds the ADMM solution to the underdetermined linear model Ax = b, with regularizer. \n\
A is a M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\
We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is user specified weight. \n\
g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n\
natively supported in admm.c with -tao_admm_regularizer_type regularizer_soft_thresh. Or user can use regular TAO solver for \n\
either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n\
Then, we augment both f and g, and solve it via ADMM. \n\
D is the N*N transform matrix so that D*x is sparse. \n";
17: typedef struct {
18: PetscInt M,N,K,reg;
19: PetscReal lambda,eps,mumin;
20: Mat A,ATA,H,Hx,D,Hz,DTD,HF;
21: Vec c,xlb,xub,x,b,workM,workN,workN2,workN3,xGT; /* observation b, ground truth xGT, the lower bound and upper bound of x*/
22: } AppCtx;
24: /*------------------------------------------------------------*/
26: PetscErrorCode NullJacobian(Tao tao,Vec X,Mat J,Mat Jpre,void *ptr)
27: {
29: return(0);
30: }
32: /*------------------------------------------------------------*/
34: static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao)
35: {
37: PetscReal lambda, mu;
38: AppCtx *user;
39: Vec out,work,y,x;
40: Tao admm_tao,misfit;
43: user = NULL;
44: mu = 0;
45: TaoGetADMMParentTao(tao,&admm_tao);
46: TaoADMMGetMisfitSubsolver(admm_tao, &misfit);
47: TaoADMMGetSpectralPenalty(admm_tao,&mu);
48: TaoShellGetContext(tao, (void**) &user);
50: lambda = user->lambda;
51: work = user->workN;
52: TaoGetSolutionVector(tao, &out);
53: TaoGetSolutionVector(misfit, &x);
54: TaoADMMGetDualVector(admm_tao, &y);
56: /* Dx + y/mu */
57: MatMult(user->D,x,work);
58: VecAXPY(work,1/mu,y);
60: /* soft thresholding */
61: TaoSoftThreshold(work, -lambda/mu, lambda/mu, out);
62: return(0);
63: }
65: /*------------------------------------------------------------*/
67: PetscErrorCode MisfitObjectiveAndGradient(Tao tao,Vec X,PetscReal *f,Vec g,void *ptr)
68: {
69: AppCtx *user = (AppCtx*)ptr;
73: /* Objective 0.5*||Ax-b||_2^2 */
74: MatMult(user->A,X,user->workM);
75: VecAXPY(user->workM,-1,user->b);
76: VecDot(user->workM,user->workM,f);
77: *f *= 0.5;
78: /* Gradient. ATAx-ATb */
79: MatMult(user->ATA,X,user->workN);
80: MatMultTranspose(user->A,user->b,user->workN2);
81: VecWAXPY(g,-1.,user->workN2,user->workN);
82: return(0);
83: }
85: /*------------------------------------------------------------*/
87: PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao,Vec X,PetscReal *f_reg,Vec G_reg,void *ptr)
88: {
89: AppCtx *user = (AppCtx*)ptr;
93: /* compute regularizer objective
94: * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
95: VecCopy(X,user->workN2);
96: VecPow(user->workN2,2.);
97: VecShift(user->workN2,user->eps*user->eps);
98: VecSqrtAbs(user->workN2);
99: VecCopy(user->workN2, user->workN3);
100: VecShift(user->workN2,-user->eps);
101: VecSum(user->workN2,f_reg);
102: *f_reg *= user->lambda;
103: /* compute regularizer gradient = lambda*x */
104: VecPointwiseDivide(G_reg,X,user->workN3);
105: VecScale(G_reg,user->lambda);
106: return(0);
107: }
109: /*------------------------------------------------------------*/
111: PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao,Vec X,PetscReal *f_reg,Vec G_reg,void *ptr)
112: {
113: AppCtx *user = (AppCtx*)ptr;
115: PetscReal temp;
118: /* compute regularizer objective = lambda*|z|_2^2 */
119: VecDot(X,X,&temp);
120: *f_reg = 0.5*user->lambda*temp;
121: /* compute regularizer gradient = lambda*z */
122: VecCopy(X,G_reg);
123: VecScale(G_reg,user->lambda);
124: return(0);
125: }
127: /*------------------------------------------------------------*/
129: static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
130: {
132: return(0);
133: }
135: /*------------------------------------------------------------*/
137: static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
138: {
139: AppCtx *user = (AppCtx*)ptr;
143: MatMult(user->D,x,user->workN);
144: VecPow(user->workN2,2.);
145: VecShift(user->workN2,user->eps*user->eps);
146: VecSqrtAbs(user->workN2);
147: VecShift(user->workN2,-user->eps);
148: VecReciprocal(user->workN2);
149: VecScale(user->workN2,user->eps*user->eps);
150: MatDiagonalSet(H,user->workN2,INSERT_VALUES);
151: return(0);
152: }
154: /*------------------------------------------------------------*/
156: PetscErrorCode FullObjGrad(Tao tao,Vec X,PetscReal *f,Vec g,void *ptr)
157: {
158: AppCtx *user = (AppCtx*)ptr;
160: PetscReal f_reg;
163: /* Objective 0.5*||Ax-b||_2^2 + lambda*||x||_2^2*/
164: MatMult(user->A,X,user->workM);
165: VecAXPY(user->workM,-1,user->b);
166: VecDot(user->workM,user->workM,f);
167: VecNorm(X,NORM_2,&f_reg);
168: *f *= 0.5;
169: *f += user->lambda*f_reg*f_reg;
170: /* Gradient. ATAx-ATb + 2*lambda*x */
171: MatMult(user->ATA,X,user->workN);
172: MatMultTranspose(user->A,user->b,user->workN2);
173: VecWAXPY(g,-1.,user->workN2,user->workN);
174: VecAXPY(g,2*user->lambda,X);
175: return(0);
176: }
177: /*------------------------------------------------------------*/
179: static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
180: {
182: return(0);
183: }
184: /*------------------------------------------------------------*/
187: PetscErrorCode InitializeUserData(AppCtx *user)
188: {
189: char dataFile[] = "tomographyData_A_b_xGT"; /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */
190: PetscViewer fd; /* used to load data from file */
192: PetscInt k,n;
193: PetscScalar v;
196: /* Load the A matrix, b vector, and xGT vector from a binary file. */
197: PetscViewerBinaryOpen(PETSC_COMM_WORLD,dataFile,FILE_MODE_READ,&fd);
198: MatCreate(PETSC_COMM_WORLD,&user->A);
199: MatSetType(user->A,MATAIJ);
200: MatLoad(user->A,fd);
201: VecCreate(PETSC_COMM_WORLD,&user->b);
202: VecLoad(user->b,fd);
203: VecCreate(PETSC_COMM_WORLD,&user->xGT);
204: VecLoad(user->xGT,fd);
205: PetscViewerDestroy(&fd);
207: MatGetSize(user->A,&user->M,&user->N);
209: MatCreate(PETSC_COMM_WORLD,&user->D);
210: MatSetSizes(user->D,PETSC_DECIDE,PETSC_DECIDE,user->N,user->N);
211: MatSetFromOptions(user->D);
212: MatSetUp(user->D);
213: for (k=0; k<user->N; k++) {
214: v = 1.0;
215: n = k+1;
216: if (k< user->N -1) {
217: MatSetValues(user->D,1,&k,1,&n,&v,INSERT_VALUES);
218: }
219: v = -1.0;
220: MatSetValues(user->D,1,&k,1,&k,&v,INSERT_VALUES);
221: }
222: MatAssemblyBegin(user->D,MAT_FINAL_ASSEMBLY);
223: MatAssemblyEnd(user->D,MAT_FINAL_ASSEMBLY);
225: MatTransposeMatMult(user->D,user->D,MAT_INITIAL_MATRIX,PETSC_DEFAULT,&user->DTD);
227: MatCreate(PETSC_COMM_WORLD,&user->Hz);
228: MatSetSizes(user->Hz,PETSC_DECIDE,PETSC_DECIDE,user->N,user->N);
229: MatSetFromOptions(user->Hz);
230: MatSetUp(user->Hz);
231: MatAssemblyBegin(user->Hz,MAT_FINAL_ASSEMBLY);
232: MatAssemblyEnd(user->Hz,MAT_FINAL_ASSEMBLY);
234: VecCreate(PETSC_COMM_WORLD,&(user->x));
235: VecCreate(PETSC_COMM_WORLD,&(user->workM));
236: VecCreate(PETSC_COMM_WORLD,&(user->workN));
237: VecCreate(PETSC_COMM_WORLD,&(user->workN2));
238: VecSetSizes(user->x,PETSC_DECIDE,user->N);
239: VecSetSizes(user->workM,PETSC_DECIDE,user->M);
240: VecSetSizes(user->workN,PETSC_DECIDE,user->N);
241: VecSetSizes(user->workN2,PETSC_DECIDE,user->N);
242: VecSetFromOptions(user->x);
243: VecSetFromOptions(user->workM);
244: VecSetFromOptions(user->workN);
245: VecSetFromOptions(user->workN2);
247: VecDuplicate(user->workN,&(user->workN3));
248: VecDuplicate(user->x,&(user->xlb));
249: VecDuplicate(user->x,&(user->xub));
250: VecDuplicate(user->x,&(user->c));
251: VecSet(user->xlb,0.0);
252: VecSet(user->c,0.0);
253: VecSet(user->xub,PETSC_INFINITY);
255: MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->ATA));
256: MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->Hx));
257: MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->HF));
259: MatAssemblyBegin(user->ATA,MAT_FINAL_ASSEMBLY);
260: MatAssemblyEnd(user->ATA,MAT_FINAL_ASSEMBLY);
261: MatAssemblyBegin(user->Hx,MAT_FINAL_ASSEMBLY);
262: MatAssemblyEnd(user->Hx,MAT_FINAL_ASSEMBLY);
263: MatAssemblyBegin(user->HF,MAT_FINAL_ASSEMBLY);
264: MatAssemblyEnd(user->HF,MAT_FINAL_ASSEMBLY);
266: user->lambda = 1.e-8;
267: user->eps = 1.e-3;
268: user->reg = 2;
269: user->mumin = 5.e-6;
271: PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c");
272: PetscOptionsInt("-reg","Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &(user->reg), NULL);
273: PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &(user->lambda), NULL);
274: PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &(user->eps), NULL);
275: PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &(user->mumin), NULL);
276: PetscOptionsEnd();
277: return(0);
278: }
280: /*------------------------------------------------------------*/
282: PetscErrorCode DestroyContext(AppCtx *user)
283: {
287: MatDestroy(&user->A);
288: MatDestroy(&user->ATA);
289: MatDestroy(&user->Hx);
290: MatDestroy(&user->Hz);
291: MatDestroy(&user->HF);
292: MatDestroy(&user->D);
293: MatDestroy(&user->DTD);
294: VecDestroy(&user->xGT);
295: VecDestroy(&user->xlb);
296: VecDestroy(&user->xub);
297: VecDestroy(&user->b);
298: VecDestroy(&user->x);
299: VecDestroy(&user->c);
300: VecDestroy(&user->workN3);
301: VecDestroy(&user->workN2);
302: VecDestroy(&user->workN);
303: VecDestroy(&user->workM);
304: return(0);
305: }
307: /*------------------------------------------------------------*/
309: int main(int argc,char **argv)
310: {
312: Tao tao,misfit,reg;
313: PetscReal v1,v2;
314: AppCtx* user;
315: PetscViewer fd;
316: char resultFile[] = "tomographyResult_x";
318: PetscInitialize(&argc,&argv,(char*)0,help);if (ierr) return ierr;
319: PetscNew(&user);
320: InitializeUserData(user);
322: TaoCreate(PETSC_COMM_WORLD, &tao);
323: TaoSetType(tao, TAOADMM);
324: TaoSetInitialVector(tao, user->x);
325: /* f(x) + g(x) for parent tao */
326: TaoADMMSetSpectralPenalty(tao,1.);
327: TaoSetObjectiveAndGradientRoutine(tao, FullObjGrad, (void*)user);
328: MatShift(user->HF,user->lambda);
329: TaoSetHessianRoutine(tao, user->HF, user->HF, HessianFull, (void*)user);
331: /* f(x) for misfit tao */
332: TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void*)user);
333: TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void*)user);
334: TaoADMMSetMisfitHessianChangeStatus(tao,PETSC_FALSE);
335: TaoADMMSetMisfitConstraintJacobian(tao,user->D,user->D,NullJacobian,(void*)user);
337: /* g(x) for regularizer tao */
338: if (user->reg == 1) {
339: TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void*)user);
340: TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void*)user);
341: TaoADMMSetRegHessianChangeStatus(tao,PETSC_TRUE);
342: } else if (user->reg == 2) {
343: TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void*)user);
344: MatShift(user->Hz,1);
345: MatScale(user->Hz,user->lambda);
346: TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void*)user);
347: TaoADMMSetRegHessianChangeStatus(tao,PETSC_TRUE);
348: } else if (user->reg != 3) SETERRQ(PETSC_COMM_WORLD, 1, "Incorrect Reg type"); /* TaoShell case */
350: /* Set type for the misfit solver */
351: TaoADMMGetMisfitSubsolver(tao, &misfit);
352: TaoADMMGetRegularizationSubsolver(tao, ®);
353: TaoSetType(misfit,TAONLS);
354: if (user->reg == 3) {
355: TaoSetType(reg,TAOSHELL);
356: TaoShellSetContext(reg, (void*) user);
357: TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold);
358: } else {
359: TaoSetType(reg,TAONLS);
360: }
361: TaoSetVariableBounds(misfit,user->xlb,user->xub);
363: /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */
364: TaoADMMSetRegularizerCoefficient(tao, user->lambda);
365: TaoADMMSetRegularizerConstraintJacobian(tao,NULL,NULL,NullJacobian,(void*)user);
366: TaoADMMSetMinimumSpectralPenalty(tao,user->mumin);
368: TaoADMMSetConstraintVectorRHS(tao,user->c);
369: TaoSetFromOptions(tao);
370: TaoSolve(tao);
372: /* Save x (reconstruction of object) vector to a binary file, which maybe read from Matlab and convert to a 2D image for comparison. */
373: PetscViewerBinaryOpen(PETSC_COMM_WORLD,resultFile,FILE_MODE_WRITE,&fd);
374: VecView(user->x,fd);
375: PetscViewerDestroy(&fd);
377: /* compute the error */
378: VecAXPY(user->x,-1,user->xGT);
379: VecNorm(user->x,NORM_2,&v1);
380: VecNorm(user->xGT,NORM_2,&v2);
381: PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1/v2));
383: /* Free TAO data structures */
384: TaoDestroy(&tao);
385: DestroyContext(user);
386: PetscFree(user);
387: PetscFinalize();
388: return ierr;
389: }
391: /*TEST
393: build:
394: requires: !complex !single !__float128 !define(PETSC_USE_64BIT_INDICES)
396: test:
397: suffix: 1
398: localrunfiles: tomographyData_A_b_xGT
399: args: -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc
401: test:
402: suffix: 2
403: localrunfiles: tomographyData_A_b_xGT
404: args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor
406: test:
407: suffix: 3
408: localrunfiles: tomographyData_A_b_xGT
409: args: -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor
411: test:
412: suffix: 4
413: localrunfiles: tomographyData_A_b_xGT
414: args: -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc
416: test:
417: suffix: 5
418: localrunfiles: tomographyData_A_b_xGT
419: args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
421: test:
422: suffix: 6
423: localrunfiles: tomographyData_A_b_xGT
424: args: -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
426: TEST*/