Actual source code: tomographyADMM.c
1: #include <petsctao.h>
/*
Description: ADMM tomography reconstruction example.
0.5*||Ax-b||^2 + lambda*g(x)
Reference: BRGN Tomography Example
*/
/* Text printed by -help; handed to PetscInitialize() in main(). */
static char help[] =
  "Finds the ADMM solution to the under-constrained linear model Ax = b, with regularizer. \n"
  "A is a M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n"
  "We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is user specified weight. \n"
  "g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n"
  "natively supported in admm.c with -tao_admm_regularizer_type regularizer_soft_thresh. Or user can use regular TAO solver for \n"
  "either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n"
  "Then, we augment both f and g, and solve it via ADMM. \n"
  "D is the N*N transform matrix so that D*x is sparse. \n";
17: typedef struct {
18: PetscInt M,N,K,reg;
19: PetscReal lambda,eps,mumin;
20: Mat A,ATA,H,Hx,D,Hz,DTD,HF;
21: Vec c,xlb,xub,x,b,workM,workN,workN2,workN3,xGT; /* observation b, ground truth xGT, the lower bound and upper bound of x*/
22: } AppCtx;
24: /*------------------------------------------------------------*/
26: PetscErrorCode NullJacobian(Tao tao,Vec X,Mat J,Mat Jpre,void *ptr)
27: {
28: return 0;
29: }
31: /*------------------------------------------------------------*/
33: static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao)
34: {
35: PetscReal lambda, mu;
36: AppCtx *user;
37: Vec out,work,y,x;
38: Tao admm_tao,misfit;
40: user = NULL;
41: mu = 0;
42: TaoGetADMMParentTao(tao,&admm_tao);
43: TaoADMMGetMisfitSubsolver(admm_tao, &misfit);
44: TaoADMMGetSpectralPenalty(admm_tao,&mu);
45: TaoShellGetContext(tao,&user);
47: lambda = user->lambda;
48: work = user->workN;
49: TaoGetSolution(tao, &out);
50: TaoGetSolution(misfit, &x);
51: TaoADMMGetDualVector(admm_tao, &y);
53: /* Dx + y/mu */
54: MatMult(user->D,x,work);
55: VecAXPY(work,1/mu,y);
57: /* soft thresholding */
58: TaoSoftThreshold(work, -lambda/mu, lambda/mu, out);
59: return 0;
60: }
62: /*------------------------------------------------------------*/
64: PetscErrorCode MisfitObjectiveAndGradient(Tao tao,Vec X,PetscReal *f,Vec g,void *ptr)
65: {
66: AppCtx *user = (AppCtx*)ptr;
68: /* Objective 0.5*||Ax-b||_2^2 */
69: MatMult(user->A,X,user->workM);
70: VecAXPY(user->workM,-1,user->b);
71: VecDot(user->workM,user->workM,f);
72: *f *= 0.5;
73: /* Gradient. ATAx-ATb */
74: MatMult(user->ATA,X,user->workN);
75: MatMultTranspose(user->A,user->b,user->workN2);
76: VecWAXPY(g,-1.,user->workN2,user->workN);
77: return 0;
78: }
80: /*------------------------------------------------------------*/
82: PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao,Vec X,PetscReal *f_reg,Vec G_reg,void *ptr)
83: {
84: AppCtx *user = (AppCtx*)ptr;
86: /* compute regularizer objective
87: * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
88: VecCopy(X,user->workN2);
89: VecPow(user->workN2,2.);
90: VecShift(user->workN2,user->eps*user->eps);
91: VecSqrtAbs(user->workN2);
92: VecCopy(user->workN2, user->workN3);
93: VecShift(user->workN2,-user->eps);
94: VecSum(user->workN2,f_reg);
95: *f_reg *= user->lambda;
96: /* compute regularizer gradient = lambda*x */
97: VecPointwiseDivide(G_reg,X,user->workN3);
98: VecScale(G_reg,user->lambda);
99: return 0;
100: }
102: /*------------------------------------------------------------*/
104: PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao,Vec X,PetscReal *f_reg,Vec G_reg,void *ptr)
105: {
106: AppCtx *user = (AppCtx*)ptr;
107: PetscReal temp;
109: /* compute regularizer objective = lambda*|z|_2^2 */
110: VecDot(X,X,&temp);
111: *f_reg = 0.5*user->lambda*temp;
112: /* compute regularizer gradient = lambda*z */
113: VecCopy(X,G_reg);
114: VecScale(G_reg,user->lambda);
115: return 0;
116: }
118: /*------------------------------------------------------------*/
120: static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
121: {
122: return 0;
123: }
125: /*------------------------------------------------------------*/
127: static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
128: {
129: AppCtx *user = (AppCtx*)ptr;
131: MatMult(user->D,x,user->workN);
132: VecPow(user->workN2,2.);
133: VecShift(user->workN2,user->eps*user->eps);
134: VecSqrtAbs(user->workN2);
135: VecShift(user->workN2,-user->eps);
136: VecReciprocal(user->workN2);
137: VecScale(user->workN2,user->eps*user->eps);
138: MatDiagonalSet(H,user->workN2,INSERT_VALUES);
139: return 0;
140: }
142: /*------------------------------------------------------------*/
144: PetscErrorCode FullObjGrad(Tao tao,Vec X,PetscReal *f,Vec g,void *ptr)
145: {
146: AppCtx *user = (AppCtx*)ptr;
147: PetscReal f_reg;
149: /* Objective 0.5*||Ax-b||_2^2 + lambda*||x||_2^2*/
150: MatMult(user->A,X,user->workM);
151: VecAXPY(user->workM,-1,user->b);
152: VecDot(user->workM,user->workM,f);
153: VecNorm(X,NORM_2,&f_reg);
154: *f *= 0.5;
155: *f += user->lambda*f_reg*f_reg;
156: /* Gradient. ATAx-ATb + 2*lambda*x */
157: MatMult(user->ATA,X,user->workN);
158: MatMultTranspose(user->A,user->b,user->workN2);
159: VecWAXPY(g,-1.,user->workN2,user->workN);
160: VecAXPY(g,2*user->lambda,X);
161: return 0;
162: }
163: /*------------------------------------------------------------*/
165: static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
166: {
167: return 0;
168: }
169: /*------------------------------------------------------------*/
171: PetscErrorCode InitializeUserData(AppCtx *user)
172: {
173: char dataFile[] = "tomographyData_A_b_xGT"; /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */
174: PetscViewer fd; /* used to load data from file */
176: PetscInt k,n;
177: PetscScalar v;
179: /* Load the A matrix, b vector, and xGT vector from a binary file. */
180: PetscViewerBinaryOpen(PETSC_COMM_WORLD,dataFile,FILE_MODE_READ,&fd);
181: MatCreate(PETSC_COMM_WORLD,&user->A);
182: MatSetType(user->A,MATAIJ);
183: MatLoad(user->A,fd);
184: VecCreate(PETSC_COMM_WORLD,&user->b);
185: VecLoad(user->b,fd);
186: VecCreate(PETSC_COMM_WORLD,&user->xGT);
187: VecLoad(user->xGT,fd);
188: PetscViewerDestroy(&fd);
190: MatGetSize(user->A,&user->M,&user->N);
192: MatCreate(PETSC_COMM_WORLD,&user->D);
193: MatSetSizes(user->D,PETSC_DECIDE,PETSC_DECIDE,user->N,user->N);
194: MatSetFromOptions(user->D);
195: MatSetUp(user->D);
196: for (k=0; k<user->N; k++) {
197: v = 1.0;
198: n = k+1;
199: if (k< user->N -1) {
200: MatSetValues(user->D,1,&k,1,&n,&v,INSERT_VALUES);
201: }
202: v = -1.0;
203: MatSetValues(user->D,1,&k,1,&k,&v,INSERT_VALUES);
204: }
205: MatAssemblyBegin(user->D,MAT_FINAL_ASSEMBLY);
206: MatAssemblyEnd(user->D,MAT_FINAL_ASSEMBLY);
208: MatTransposeMatMult(user->D,user->D,MAT_INITIAL_MATRIX,PETSC_DEFAULT,&user->DTD);
210: MatCreate(PETSC_COMM_WORLD,&user->Hz);
211: MatSetSizes(user->Hz,PETSC_DECIDE,PETSC_DECIDE,user->N,user->N);
212: MatSetFromOptions(user->Hz);
213: MatSetUp(user->Hz);
214: MatAssemblyBegin(user->Hz,MAT_FINAL_ASSEMBLY);
215: MatAssemblyEnd(user->Hz,MAT_FINAL_ASSEMBLY);
217: VecCreate(PETSC_COMM_WORLD,&(user->x));
218: VecCreate(PETSC_COMM_WORLD,&(user->workM));
219: VecCreate(PETSC_COMM_WORLD,&(user->workN));
220: VecCreate(PETSC_COMM_WORLD,&(user->workN2));
221: VecSetSizes(user->x,PETSC_DECIDE,user->N);
222: VecSetSizes(user->workM,PETSC_DECIDE,user->M);
223: VecSetSizes(user->workN,PETSC_DECIDE,user->N);
224: VecSetSizes(user->workN2,PETSC_DECIDE,user->N);
225: VecSetFromOptions(user->x);
226: VecSetFromOptions(user->workM);
227: VecSetFromOptions(user->workN);
228: VecSetFromOptions(user->workN2);
230: VecDuplicate(user->workN,&(user->workN3));
231: VecDuplicate(user->x,&(user->xlb));
232: VecDuplicate(user->x,&(user->xub));
233: VecDuplicate(user->x,&(user->c));
234: VecSet(user->xlb,0.0);
235: VecSet(user->c,0.0);
236: VecSet(user->xub,PETSC_INFINITY);
238: MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->ATA));
239: MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->Hx));
240: MatTransposeMatMult(user->A,user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &(user->HF));
242: MatAssemblyBegin(user->ATA,MAT_FINAL_ASSEMBLY);
243: MatAssemblyEnd(user->ATA,MAT_FINAL_ASSEMBLY);
244: MatAssemblyBegin(user->Hx,MAT_FINAL_ASSEMBLY);
245: MatAssemblyEnd(user->Hx,MAT_FINAL_ASSEMBLY);
246: MatAssemblyBegin(user->HF,MAT_FINAL_ASSEMBLY);
247: MatAssemblyEnd(user->HF,MAT_FINAL_ASSEMBLY);
249: user->lambda = 1.e-8;
250: user->eps = 1.e-3;
251: user->reg = 2;
252: user->mumin = 5.e-6;
254: PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c");
255: PetscOptionsInt("-reg","Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &(user->reg), NULL);
256: PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &(user->lambda), NULL);
257: PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &(user->eps), NULL);
258: PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &(user->mumin), NULL);
259: PetscOptionsEnd();
260: return 0;
261: }
263: /*------------------------------------------------------------*/
265: PetscErrorCode DestroyContext(AppCtx *user)
266: {
267: MatDestroy(&user->A);
268: MatDestroy(&user->ATA);
269: MatDestroy(&user->Hx);
270: MatDestroy(&user->Hz);
271: MatDestroy(&user->HF);
272: MatDestroy(&user->D);
273: MatDestroy(&user->DTD);
274: VecDestroy(&user->xGT);
275: VecDestroy(&user->xlb);
276: VecDestroy(&user->xub);
277: VecDestroy(&user->b);
278: VecDestroy(&user->x);
279: VecDestroy(&user->c);
280: VecDestroy(&user->workN3);
281: VecDestroy(&user->workN2);
282: VecDestroy(&user->workN);
283: VecDestroy(&user->workM);
284: return 0;
285: }
287: /*------------------------------------------------------------*/
289: int main(int argc,char **argv)
290: {
291: Tao tao,misfit,reg;
292: PetscReal v1,v2;
293: AppCtx* user;
294: PetscViewer fd;
295: char resultFile[] = "tomographyResult_x";
297: PetscInitialize(&argc,&argv,(char*)0,help);
298: PetscNew(&user);
299: InitializeUserData(user);
301: TaoCreate(PETSC_COMM_WORLD, &tao);
302: TaoSetType(tao, TAOADMM);
303: TaoSetSolution(tao, user->x);
304: /* f(x) + g(x) for parent tao */
305: TaoADMMSetSpectralPenalty(tao,1.);
306: TaoSetObjectiveAndGradient(tao,NULL, FullObjGrad, (void*)user);
307: MatShift(user->HF,user->lambda);
308: TaoSetHessian(tao, user->HF, user->HF, HessianFull, (void*)user);
310: /* f(x) for misfit tao */
311: TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void*)user);
312: TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void*)user);
313: TaoADMMSetMisfitHessianChangeStatus(tao,PETSC_FALSE);
314: TaoADMMSetMisfitConstraintJacobian(tao,user->D,user->D,NullJacobian,(void*)user);
316: /* g(x) for regularizer tao */
317: if (user->reg == 1) {
318: TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void*)user);
319: TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void*)user);
320: TaoADMMSetRegHessianChangeStatus(tao,PETSC_TRUE);
321: } else if (user->reg == 2) {
322: TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void*)user);
323: MatShift(user->Hz,1);
324: MatScale(user->Hz,user->lambda);
325: TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void*)user);
326: TaoADMMSetRegHessianChangeStatus(tao,PETSC_TRUE);
329: /* Set type for the misfit solver */
330: TaoADMMGetMisfitSubsolver(tao, &misfit);
331: TaoADMMGetRegularizationSubsolver(tao, ®);
332: TaoSetType(misfit,TAONLS);
333: if (user->reg == 3) {
334: TaoSetType(reg,TAOSHELL);
335: TaoShellSetContext(reg, (void*) user);
336: TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold);
337: } else {
338: TaoSetType(reg,TAONLS);
339: }
340: TaoSetVariableBounds(misfit,user->xlb,user->xub);
342: /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */
343: TaoADMMSetRegularizerCoefficient(tao, user->lambda);
344: TaoADMMSetRegularizerConstraintJacobian(tao,NULL,NULL,NullJacobian,(void*)user);
345: TaoADMMSetMinimumSpectralPenalty(tao,user->mumin);
347: TaoADMMSetConstraintVectorRHS(tao,user->c);
348: TaoSetFromOptions(tao);
349: TaoSolve(tao);
351: /* Save x (reconstruction of object) vector to a binary file, which maybe read from Matlab and convert to a 2D image for comparison. */
352: PetscViewerBinaryOpen(PETSC_COMM_WORLD,resultFile,FILE_MODE_WRITE,&fd);
353: VecView(user->x,fd);
354: PetscViewerDestroy(&fd);
356: /* compute the error */
357: VecAXPY(user->x,-1,user->xGT);
358: VecNorm(user->x,NORM_2,&v1);
359: VecNorm(user->xGT,NORM_2,&v2);
360: PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1/v2));
362: /* Free TAO data structures */
363: TaoDestroy(&tao);
364: DestroyContext(user);
365: PetscFree(user);
366: PetscFinalize();
367: return 0;
368: }
370: /*TEST
372: build:
373: requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)
375: test:
376: suffix: 1
377: localrunfiles: tomographyData_A_b_xGT
378: args: -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc
380: test:
381: suffix: 2
382: localrunfiles: tomographyData_A_b_xGT
383: args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor
385: test:
386: suffix: 3
387: localrunfiles: tomographyData_A_b_xGT
388: args: -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor
390: test:
391: suffix: 4
392: localrunfiles: tomographyData_A_b_xGT
393: args: -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc
395: test:
396: suffix: 5
397: localrunfiles: tomographyData_A_b_xGT
398: args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
400: test:
401: suffix: 6
402: localrunfiles: tomographyData_A_b_xGT
403: args: -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
405: TEST*/