Actual source code: taotermtest2.c
1: #include <petsctao.h>
3: static char help[] = "Using TaoTermShell with mapping matrices that are not diagonal.\n";
5: typedef struct {
6: Vec pdiff_work; /* Work vector for x - params */
7: } HalfL2Ctx;
9: typedef struct {
10: Mat A; /* Mapping matrix A */
11: Vec p; /* Target vector p */
12: Vec Ax; /* Work vector for A*x */
13: Vec Ax_p; /* Work vector for A*x - p */
14: } CallbackCtx;
16: static PetscErrorCode FormFunctionGradient(TaoTerm, Vec, Vec, PetscReal *, Vec);
17: static PetscErrorCode FormHessian(TaoTerm, Vec, Vec, Mat, Mat);
18: static PetscErrorCode CtxDestroy(PetscCtxRt ctx);
20: /* Callback functions for traditional TAO interface */
21: static PetscErrorCode FormObjectiveGradient_Callback(Tao, Vec, PetscReal *, Vec, void *);
22: static PetscErrorCode FormHessian_Callback(Tao, Vec, Mat, Mat, void *);
24: int main(int argc, char **argv)
25: {
26: TaoTerm objective;
27: Tao tao, tao2;
28: PetscMPIInt size;
29: HalfL2Ctx *ctx;
30: MPI_Comm comm;
31: PetscInt n = 10, m = 10;
32: Mat A;
33: Vec target;
34: CallbackCtx *cb_ctx;
35: Vec x_term, x_callback, x2, diff;
36: Mat H2;
37: PetscReal norm_diff, diag_val = 1.1;
38: PetscBool opt, is_diag, is_cdiag, is_aij, is_dense, fd_notpossible;
39: const char *mtype = MATAIJ;
40: char typeName[256] = "";
42: PetscFunctionBeginUser;
43: PetscCall(PetscInitialize(&argc, &argv, NULL, help));
44: comm = PETSC_COMM_WORLD;
45: PetscCallMPI(MPI_Comm_size(comm, &size));
46: PetscCheck(size == 1, comm, PETSC_ERR_WRONG_MPI_SIZE, "Incorrect number of processors");
48: fd_notpossible = PETSC_FALSE;
50: PetscOptionsBegin(comm, "", help, "none");
51: PetscCall(PetscOptionsBool("-fd_notpossible", "Set TaoTermShell ComputeHessianFDPossible as false", "", fd_notpossible, &fd_notpossible, NULL));
52: PetscCall(PetscOptionsInt("-n", "Problem size", "", n, &n, NULL));
53: PetscCall(PetscOptionsInt("-m", "Mapping matrix row size", "", m, &m, NULL));
54: PetscCall(PetscOptionsReal("-diag_val", "Value of constant diagonal matrix", NULL, diag_val, &diag_val, NULL));
55: PetscCall(PetscOptionsFList("-mapping_mtype", "Mapping matrix type", "", MatList, mtype, typeName, 256, &opt));
56: PetscOptionsEnd();
58: PetscCall(PetscNew(&ctx));
60: /* Initialize typeName to default if option was not set */
61: if (!opt) PetscCall(PetscStrcpy(typeName, mtype));
63: PetscCall(PetscStrcmp(typeName, MATDIAGONAL, &is_diag));
64: PetscCall(PetscStrcmp(typeName, MATCONSTANTDIAGONAL, &is_cdiag));
65: PetscCall(PetscStrcmp(typeName, MATAIJ, &is_aij));
66: PetscCall(PetscStrcmp(typeName, MATDENSE, &is_dense));
67: /* Create mapping matrix A: m x n (maps from solution space to term space) */
68: if (is_diag) {
69: /* Create a diagonal matrix */
70: Vec diag_vec;
71: PetscInt diag_size;
73: PetscCheck(m == n, comm, PETSC_ERR_ARG_INCOMP, "For diagonal matrix, m and n must be equal (got m=%" PetscInt_FMT ", n=%" PetscInt_FMT ")", m, n);
74: diag_size = m;
75: PetscCall(VecCreate(comm, &diag_vec));
76: PetscCall(VecSetSizes(diag_vec, PETSC_DECIDE, diag_size));
77: PetscCall(VecSetFromOptions(diag_vec));
78: PetscCall(VecSetRandom(diag_vec, NULL));
79: PetscCall(MatCreateDiagonal(diag_vec, &A));
80: PetscCall(VecDestroy(&diag_vec));
81: } else if (is_cdiag) {
82: /* Create a constant diagonal matrix */
83: PetscCheck(m == n, comm, PETSC_ERR_ARG_INCOMP, "For constant diagonal matrix, m and n must be equal (got m=%" PetscInt_FMT ", n=%" PetscInt_FMT ")", m, n);
84: PetscCall(MatCreateConstantDiagonal(comm, PETSC_DECIDE, PETSC_DECIDE, m, n, diag_val, &A));
85: } else if (is_dense) {
86: /* Create a dense matrix */
87: PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, m, n, NULL, &A));
88: PetscCall(MatSetFromOptions(A));
89: PetscCall(MatSetRandom(A, NULL));
90: PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
91: PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
92: } else {
93: /* Create an AIJ matrix (default) */
94: PetscCall(MatCreateSeqAIJ(comm, m, n, PETSC_DEFAULT, NULL, &A));
95: PetscCall(MatSetFromOptions(A));
96: PetscCall(MatSetRandom(A, NULL));
97: PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
98: PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
99: }
101: /* Create shell term that computes f(x) = 0.5 ||x||_2^2 */
102: PetscCall(TaoTermCreateShell(comm, ctx, CtxDestroy, &objective));
104: /* Set solution and parameter sizes to match the mapped space (m) */
105: PetscCall(TaoTermSetSolutionSizes(objective, PETSC_DECIDE, m, 1));
106: PetscCall(TaoTermSetParametersSizes(objective, PETSC_DECIDE, m, 1));
108: PetscCall(TaoTermShellSetObjectiveAndGradient(objective, FormFunctionGradient));
109: PetscCall(TaoTermShellSetCreateHessianMatrices(objective, TaoTermCreateHessianMatricesDefault));
110: PetscCall(TaoTermSetCreateHessianMode(objective, PETSC_TRUE /* H == Hpre */, MATAIJ, NULL));
111: PetscCall(TaoTermShellSetHessian(objective, FormHessian));
112: PetscCall(TaoTermSetFromOptions(objective));
113: if (fd_notpossible) PetscCall(TaoTermShellSetIsComputeHessianFDPossible(objective, PETSC_BOOL3_FALSE));
115: PetscCall(TaoTermSetUp(objective));
117: /* Create target vector for least squares problem (parameters) */
118: PetscCall(TaoTermCreateParametersVec(objective, &target));
119: PetscCall(VecSetRandom(target, NULL));
121: PetscCall(TaoCreate(comm, &tao));
122: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)tao, "shell_"));
123: PetscCall(TaoSetType(tao, TAOLMVM));
125: /* Add term with mapping matrix A: f(Ax; p) = 0.5 ||Ax - p||_2^2 */
126: PetscCall(TaoAddTerm(tao, NULL, 1.0, objective, target, A));
128: PetscCall(TaoSetFromOptions(tao));
129: PetscCall(TaoSolve(tao));
131: /* Allocate callback context */
132: PetscCall(PetscNew(&cb_ctx));
133: cb_ctx->A = A;
134: cb_ctx->p = target;
136: /* Create work vectors */
137: PetscCall(MatCreateVecs(A, NULL, &cb_ctx->Ax));
138: PetscCall(VecDuplicate(target, &cb_ctx->Ax_p));
140: PetscCall(MatCreateVecs(A, &x2, NULL));
141: PetscCall(VecZeroEntries(x2));
143: /* Create Hessian matrix A^T * A */
144: if (is_diag) {
145: Vec A_diag, H2_diag;
147: PetscCall(MatCreateVecs(A, &A_diag, NULL));
148: PetscCall(MatGetDiagonal(A, A_diag));
149: PetscCall(VecDuplicate(A_diag, &H2_diag));
150: PetscCall(VecPointwiseMult(H2_diag, A_diag, A_diag));
151: PetscCall(MatCreateDiagonal(H2_diag, &H2));
152: PetscCall(VecDestroy(&A_diag));
153: PetscCall(VecDestroy(&H2_diag));
154: } else if (is_cdiag) {
155: PetscCall(MatCreateConstantDiagonal(comm, PETSC_DECIDE, PETSC_DECIDE, m, n, diag_val * diag_val, &H2));
156: } else {
157: Mat Htest, Hpretest;
158: PetscBool is_h_dense;
160: PetscCall(MatTransposeMatMult(A, A, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &H2));
161: PetscCall(MatAssemblyBegin(H2, MAT_FINAL_ASSEMBLY));
162: PetscCall(MatAssemblyEnd(H2, MAT_FINAL_ASSEMBLY));
164: PetscCall(TaoGetHessianMatrices(tao, &Htest, &Hpretest));
165: PetscCall(PetscObjectBaseTypeCompare((PetscObject)Htest, MATSEQDENSE, &is_h_dense));
166: if (is_h_dense) PetscCall(MatConvert(H2, MATDENSE, MAT_INPLACE_MATRIX, &H2));
167: }
168: /* Create second TAO solver */
169: PetscCall(TaoCreate(comm, &tao2));
170: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)tao2, "regular_"));
171: PetscCall(TaoSetType(tao2, TAOLMVM));
172: PetscCall(TaoSetSolution(tao2, x2));
173: PetscCall(TaoSetObjectiveAndGradient(tao2, NULL, FormObjectiveGradient_Callback, cb_ctx));
174: PetscCall(TaoSetHessian(tao2, H2, H2, FormHessian_Callback, cb_ctx));
175: PetscCall(TaoSetFromOptions(tao2));
176: PetscCall(TaoSolve(tao2));
178: /* Compare solutions */
179: PetscCall(TaoGetSolution(tao, &x_term));
180: PetscCall(TaoGetSolution(tao2, &x_callback));
181: PetscCall(VecDuplicate(x_term, &diff));
182: PetscCall(VecCopy(x_term, diff));
183: PetscCall(VecAXPY(diff, -1.0, x_callback));
184: PetscCall(VecNorm(diff, NORM_2, &norm_diff));
185: if (norm_diff <= 1.e-12) PetscCall(PetscPrintf(comm, "Relative difference < 1e-12\n"));
186: else PetscCall(PetscPrintf(comm, "Relative difference > 1e-12: %6.10e\n", (double)norm_diff));
187: PetscCall(VecDestroy(&x2));
188: PetscCall(VecDestroy(&diff));
189: PetscCall(VecDestroy(&cb_ctx->Ax));
190: PetscCall(VecDestroy(&cb_ctx->Ax_p));
191: PetscCall(PetscFree(cb_ctx));
192: PetscCall(VecDestroy(&target));
193: PetscCall(MatDestroy(&A));
194: PetscCall(MatDestroy(&H2));
195: PetscCall(TaoDestroy(&tao2));
196: PetscCall(TaoDestroy(&tao));
197: PetscCall(TaoTermDestroy(&objective));
198: PetscCall(PetscFinalize());
199: return 0;
200: }
202: /*
203: FormFunctionGradient - Evaluates the function, f(X), and gradient, G(X).
205: Input Parameters:
206: + term - the `TaoTerm` for the objective function
207: . x - input vector
208: - params - optional vector of parameters
210: Output Parameters:
211: + f - function value
212: - G - vector containing the newly evaluated gradient
214: Note:
215: Computes f = 0.5 * ||x - params||_2^2 and g = x - params, matching TAOTERMHALFL2SQUARED.
216: */
217: static PetscErrorCode FormFunctionGradient(TaoTerm term, Vec x, Vec params, PetscReal *f, Vec G)
218: {
219: HalfL2Ctx *ctx;
220: PetscScalar v;
222: PetscFunctionBeginUser;
223: PetscCall(TaoTermShellGetContext(term, &ctx));
224: if (params) {
225: PetscCall(VecWAXPY(G, -1.0, params, x));
226: PetscCall(VecDot(G, G, &v));
227: } else {
228: PetscCall(VecCopy(x, G));
229: PetscCall(VecDot(G, G, &v));
230: }
231: *f = 0.5 * PetscRealPart(v);
232: PetscFunctionReturn(PETSC_SUCCESS);
233: }
235: /*
236: FormHessian - Evaluates Hessian matrix.
238: Input Parameters:
239: + term - the `TaoTerm` for the objective function
240: . x - input vector
241: . params - optional vector of parameters
242: - Hpre - optional preconditioner matrix
244: Output Parameters:
245: + H - Hessian matrix
246: - Hpre - Preconditioning matrix
248: Note:
249: Computes H = I (identity matrix), matching TAOTERMHALFL2SQUARED.
250: */
251: static PetscErrorCode FormHessian(TaoTerm term, Vec x, Vec params, Mat H, Mat Hpre)
252: {
253: PetscFunctionBeginUser;
254: if (H) {
255: PetscCall(MatZeroEntries(H));
256: PetscCall(MatAssemblyBegin(H, MAT_FINAL_ASSEMBLY));
257: PetscCall(MatAssemblyEnd(H, MAT_FINAL_ASSEMBLY));
258: PetscCall(MatShift(H, 1.0));
259: }
260: if (Hpre && Hpre != H) {
261: PetscCall(MatZeroEntries(Hpre));
262: PetscCall(MatAssemblyBegin(Hpre, MAT_FINAL_ASSEMBLY));
263: PetscCall(MatAssemblyEnd(Hpre, MAT_FINAL_ASSEMBLY));
264: PetscCall(MatShift(Hpre, 1.0));
265: }
266: PetscFunctionReturn(PETSC_SUCCESS);
267: }
269: static PetscErrorCode CtxDestroy(PetscCtxRt ctx_ptr)
270: {
271: HalfL2Ctx *ctx = *(HalfL2Ctx **)ctx_ptr;
273: PetscFunctionBeginUser;
274: if (ctx) {
275: PetscCall(VecDestroy(&ctx->pdiff_work));
276: PetscCall(PetscFree(ctx));
277: *(void **)ctx_ptr = NULL;
278: }
279: PetscFunctionReturn(PETSC_SUCCESS);
280: }
282: /*
283: FormObjectiveGradient_Callback - Evaluates the objective and gradient for traditional TAO callback interface.
285: Input Parameters:
286: + tao - the Tao solver context
287: . x - input vector (size n)
288: - ctx - user context containing A and p
290: Output Parameters:
291: + f - function value: 0.5 * ||Ax - p||_2^2
292: - g - gradient vector: A^T (Ax - p)
294: Note:
295: Computes f = 0.5 * ||Ax - p||_2^2 and g = A^T (Ax - p)
296: */
297: static PetscErrorCode FormObjectiveGradient_Callback(Tao tao, Vec x, PetscReal *f, Vec g, void *ctx)
298: {
299: CallbackCtx *cb_ctx = (CallbackCtx *)ctx;
300: PetscScalar v;
302: PetscFunctionBeginUser;
303: /* Compute Ax */
304: PetscCall(MatMult(cb_ctx->A, x, cb_ctx->Ax));
305: /* Compute Ax - p */
306: PetscCall(VecCopy(cb_ctx->Ax, cb_ctx->Ax_p));
307: PetscCall(VecAXPY(cb_ctx->Ax_p, -1.0, cb_ctx->p));
308: /* Compute objective: 0.5 * ||Ax - p||_2^2 */
309: PetscCall(VecDot(cb_ctx->Ax_p, cb_ctx->Ax_p, &v));
310: *f = 0.5 * PetscRealPart(v);
311: /* Compute gradient: A^T (Ax - p) */
312: PetscCall(MatMultTranspose(cb_ctx->A, cb_ctx->Ax_p, g));
313: PetscFunctionReturn(PETSC_SUCCESS);
314: }
316: /*
317: FormHessian_Callback - Evaluates the Hessian matrix for traditional TAO callback interface.
319: Input Parameters:
320: + tao - the Tao solver context
321: . x - input vector
322: . H - Hessian matrix (should be pre-allocated as A^T * A)
323: . Hpre - preconditioner matrix
324: - ctx - user context containing A and p
326: Output Parameters:
327: + H - Hessian matrix (A^T * A)
328: - Hpre - Preconditioning matrix
330: Note:
331: The Hessian for 0.5 * ||Ax - p||_2^2 is constant: H = A^T * A
332: */
333: static PetscErrorCode FormHessian_Callback(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx)
334: {
335: PetscFunctionBeginUser;
336: /* Hessian is constant: A^T * A, which should already be set in H */
337: if (Hpre && Hpre != H) PetscCall(MatCopy(H, Hpre, SAME_NONZERO_PATTERN));
338: PetscFunctionReturn(PETSC_SUCCESS);
339: }
/* Note: For the dense variants, the difference between the two solutions may exceed 1.e-12.
 * This is expected: it stems from the KSP and PC operating on AIJ matrices rather than
 * dense ones. */
345: /*TEST
347: build:
348: requires: !complex !single !quad !defined(PETSC_USE_64BIT_INDICES) !__float128
350: test:
351: suffix: diag_diag
352: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
353: args: -tao_term_hessian_mat_type diagonal -mapping_mtype diagonal
355: test:
356: suffix: diag_cdiag
357: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
358: args: -tao_term_hessian_mat_type diagonal -mapping_mtype constantdiagonal
360: test:
361: suffix: diag_dense
362: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
363: args: -tao_term_hessian_mat_type diagonal -mapping_mtype dense
365: test:
366: suffix: diag_dense_nsq
367: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
368: args: -tao_term_hessian_mat_type diagonal -mapping_mtype dense -m 15
370: test:
371: suffix: diag_aij
372: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
373: args: -tao_term_hessian_mat_type diagonal -mapping_mtype aij
375: test:
376: suffix: cdiag_diag
377: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
378: args: -tao_term_hessian_mat_type constantdiagonal -mapping_mtype diagonal
380: test:
381: suffix: cdiag_cdiag
382: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
383: args: -tao_term_hessian_mat_type constantdiagonal -mapping_mtype constantdiagonal
385: test:
386: suffix: cdiag_dense
387: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
388: args: -tao_term_hessian_mat_type constantdiagonal -mapping_mtype dense
390: test:
391: suffix: cdiag_dense_nsq
392: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
393: args: -tao_term_hessian_mat_type constantdiagonal -mapping_mtype dense -m 15
395: test:
396: suffix: cdiag_aij
397: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
398: args: -tao_term_hessian_mat_type constantdiagonal -mapping_mtype aij
400: test:
401: suffix: dense_diag
402: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
403: args: -tao_term_hessian_mat_type dense -mapping_mtype diagonal
405: test:
406: suffix: dense_cdiag
407: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
408: args: -tao_term_hessian_mat_type dense -mapping_mtype constantdiagonal
410: test:
411: suffix: dense_dense
412: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
413: args: -tao_term_hessian_mat_type dense -mapping_mtype dense -fd_notpossible {{0 1}}
415: test:
416: suffix: dense_dense_nsq
417: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
418: args: -tao_term_hessian_mat_type dense -mapping_mtype dense -m 15
420: test:
421: suffix: dense_aij
422: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
423: args: -tao_term_hessian_mat_type dense -mapping_mtype aij
425: test:
426: suffix: aij_diag
427: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
428: args: -tao_term_hessian_mat_type aij -mapping_mtype diagonal
430: test:
431: suffix: aij_cdiag
432: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
433: args: -tao_term_hessian_mat_type aij -mapping_mtype constantdiagonal
435: test:
436: suffix: aij_dense
437: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
438: args: -tao_term_hessian_mat_type aij -mapping_mtype dense
440: test:
441: suffix: aij_dense_nsq
442: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
443: args: -tao_term_hessian_mat_type aij -mapping_mtype dense -m 15
445: test:
446: suffix: aij_aij
447: args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
448: args: -tao_term_hessian_mat_type aij -mapping_mtype aij
450: TEST*/