Actual source code: ex23fwdadj.c
/* Help message handed to PetscInitialize() in main(). */
static char help[] = "A toy example for testing forward and adjoint sensitivity analysis of an implicit ODE with a parametrized mass matrix.\n";
/*
  This example solves the simple ODE
      c x' = b x,   x(0) = a,
  whose analytical solution is x(T) = a*exp(b/c*T). It computes the derivative of x(T)
  with respect to c (the default, -der 1) or with respect to b (command-line option -der 2),
  using both forward and adjoint sensitivity analysis.
*/
10: #include <petscts.h>
12: typedef struct _n_User *User;
13: struct _n_User {
14: PetscReal a;
15: PetscReal b;
16: PetscReal c;
17: /* Sensitivity analysis support */
18: PetscInt steps;
19: PetscReal ftime;
20: Mat Jac; /* Jacobian matrix */
21: Mat Jacp; /* JacobianP matrix */
22: Vec x;
23: Mat sp; /* forward sensitivity variables */
24: Vec lambda[1]; /* adjoint sensitivity variables */
25: Vec mup[1]; /* adjoint sensitivity variables */
26: PetscInt der;
27: };
29: static PetscErrorCode IFunction(TS ts, PetscReal t, Vec X, Vec Xdot, Vec F, void *ctx)
30: {
31: User user = (User)ctx;
32: const PetscScalar *x, *xdot;
33: PetscScalar *f;
35: PetscFunctionBeginUser;
36: PetscCall(VecGetArrayRead(X, &x));
37: PetscCall(VecGetArrayRead(Xdot, &xdot));
38: PetscCall(VecGetArrayWrite(F, &f));
39: f[0] = user->c * xdot[0] - user->b * x[0];
40: PetscCall(VecRestoreArrayRead(X, &x));
41: PetscCall(VecRestoreArrayRead(Xdot, &xdot));
42: PetscCall(VecRestoreArrayWrite(F, &f));
43: PetscFunctionReturn(PETSC_SUCCESS);
44: }
46: static PetscErrorCode IJacobian(TS ts, PetscReal t, Vec X, Vec Xdot, PetscReal a, Mat A, Mat B, void *ctx)
47: {
48: User user = (User)ctx;
49: PetscInt rowcol[] = {0};
50: PetscScalar J[1][1];
51: const PetscScalar *x;
53: PetscFunctionBeginUser;
54: PetscCall(VecGetArrayRead(X, &x));
55: J[0][0] = user->c * a - user->b * 1.0;
56: PetscCall(MatSetValues(B, 1, rowcol, 1, rowcol, &J[0][0], INSERT_VALUES));
57: PetscCall(VecRestoreArrayRead(X, &x));
59: PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
60: PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
61: if (A != B) {
62: PetscCall(MatAssemblyBegin(B, MAT_FINAL_ASSEMBLY));
63: PetscCall(MatAssemblyEnd(B, MAT_FINAL_ASSEMBLY));
64: }
65: PetscFunctionReturn(PETSC_SUCCESS);
66: }
68: static PetscErrorCode IJacobianP(TS ts, PetscReal t, Vec X, Vec Xdot, PetscReal shift, Mat A, void *ctx)
69: {
70: User user = (User)ctx;
71: PetscInt row[] = {0}, col[] = {0};
72: PetscScalar J[1][1];
73: const PetscScalar *x, *xdot;
74: PetscReal dt;
76: PetscFunctionBeginUser;
77: PetscCall(VecGetArrayRead(X, &x));
78: PetscCall(VecGetArrayRead(Xdot, &xdot));
79: PetscCall(TSGetTimeStep(ts, &dt));
80: if (user->der == 1) J[0][0] = xdot[0];
81: if (user->der == 2) J[0][0] = -x[0];
82: PetscCall(MatSetValues(A, 1, row, 1, col, &J[0][0], INSERT_VALUES));
83: PetscCall(VecRestoreArrayRead(X, &x));
85: PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
86: PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
87: PetscFunctionReturn(PETSC_SUCCESS);
88: }
90: int main(int argc, char **argv)
91: {
92: TS ts;
93: PetscScalar *x_ptr;
94: PetscMPIInt size;
95: struct _n_User user;
96: PetscInt rows, cols;
98: PetscFunctionBeginUser;
99: PetscCall(PetscInitialize(&argc, &argv, NULL, help));
101: PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
102: PetscCheck(size == 1, PETSC_COMM_WORLD, PETSC_ERR_WRONG_MPI_SIZE, "This is a uniprocessor example only!");
104: user.a = 2.0;
105: user.b = 4.0;
106: user.c = 3.0;
107: user.steps = 0;
108: user.ftime = 1.0;
109: user.der = 1;
110: PetscCall(PetscOptionsGetInt(NULL, NULL, "-der", &user.der, NULL));
112: rows = 1;
113: cols = 1;
114: PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jac));
115: PetscCall(MatSetSizes(user.Jac, PETSC_DECIDE, PETSC_DECIDE, 1, 1));
116: PetscCall(MatSetFromOptions(user.Jac));
117: PetscCall(MatSetUp(user.Jac));
118: PetscCall(MatCreateVecs(user.Jac, &user.x, NULL));
120: PetscCall(TSCreate(PETSC_COMM_WORLD, &ts));
121: PetscCall(TSSetType(ts, TSBEULER));
122: PetscCall(TSSetIFunction(ts, NULL, IFunction, &user));
123: PetscCall(TSSetIJacobian(ts, user.Jac, user.Jac, IJacobian, &user));
124: PetscCall(TSSetExactFinalTime(ts, TS_EXACTFINALTIME_MATCHSTEP));
125: PetscCall(TSSetMaxTime(ts, user.ftime));
127: PetscCall(VecGetArrayWrite(user.x, &x_ptr));
128: x_ptr[0] = user.a;
129: PetscCall(VecRestoreArrayWrite(user.x, &x_ptr));
130: PetscCall(TSSetTimeStep(ts, 0.001));
132: /* Set up forward sensitivity */
133: PetscCall(MatCreate(PETSC_COMM_WORLD, &user.Jacp));
134: PetscCall(MatSetSizes(user.Jacp, PETSC_DECIDE, PETSC_DECIDE, rows, cols));
135: PetscCall(MatSetFromOptions(user.Jacp));
136: PetscCall(MatSetUp(user.Jacp));
137: PetscCall(MatCreateDense(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, rows, cols, NULL, &user.sp));
138: PetscCall(MatZeroEntries(user.sp));
139: PetscCall(TSForwardSetSensitivities(ts, cols, user.sp));
140: PetscCall(TSSetIJacobianP(ts, user.Jacp, IJacobianP, &user));
142: PetscCall(TSSetSaveTrajectory(ts));
143: PetscCall(TSSetFromOptions(ts));
145: PetscCall(TSSolve(ts, user.x));
146: PetscCall(TSGetSolveTime(ts, &user.ftime));
147: PetscCall(TSGetStepNumber(ts, &user.steps));
148: PetscCall(VecGetArray(user.x, &x_ptr));
149: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n ode solution %g\n", (double)PetscRealPart(x_ptr[0])));
150: PetscCall(VecRestoreArray(user.x, &x_ptr));
151: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical solution %g\n", (double)(user.a * PetscExpReal(user.b / user.c * user.ftime))));
153: if (user.der == 1) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical derivative w.r.t. c %g\n", (double)(-user.a * user.ftime * user.b / (user.c * user.c) * PetscExpReal(user.b / user.c * user.ftime))));
154: if (user.der == 2) PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n analytical derivative w.r.t. b %g\n", (double)(user.a * user.ftime / user.c * PetscExpReal(user.b / user.c * user.ftime))));
155: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n forward sensitivity:\n"));
156: PetscCall(MatView(user.sp, PETSC_VIEWER_STDOUT_WORLD));
158: PetscCall(MatCreateVecs(user.Jac, &user.lambda[0], NULL));
159: /* Set initial conditions for the adjoint integration */
160: PetscCall(VecGetArrayWrite(user.lambda[0], &x_ptr));
161: x_ptr[0] = 1.0;
162: PetscCall(VecRestoreArrayWrite(user.lambda[0], &x_ptr));
163: PetscCall(MatCreateVecs(user.Jacp, &user.mup[0], NULL));
164: PetscCall(VecGetArrayWrite(user.mup[0], &x_ptr));
165: x_ptr[0] = 0.0;
166: PetscCall(VecRestoreArrayWrite(user.mup[0], &x_ptr));
168: PetscCall(TSSetCostGradients(ts, 1, user.lambda, user.mup));
169: PetscCall(TSAdjointSolve(ts));
171: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "\n adjoint sensitivity:\n"));
172: PetscCall(VecView(user.mup[0], PETSC_VIEWER_STDOUT_WORLD));
174: PetscCall(MatDestroy(&user.Jac));
175: PetscCall(MatDestroy(&user.sp));
176: PetscCall(MatDestroy(&user.Jacp));
177: PetscCall(VecDestroy(&user.x));
178: PetscCall(VecDestroy(&user.lambda[0]));
179: PetscCall(VecDestroy(&user.mup[0]));
180: PetscCall(TSDestroy(&ts));
182: PetscCall(PetscFinalize());
183: return 0;
184: }
186: /*TEST
188: test:
189: args: -ts_type beuler
191: test:
192: suffix: 2
193: args: -ts_type cn
194: output_file: output/ex23fwdadj_1.out
196: TEST*/