Actual source code: solve_performance.c
/* Program description string: handed to PetscInitialize() and PetscOptionsBegin() in main(). */
const char help[] = "Profile the performance of "
                    "MATLMVM MatSolve() in a loop";
3: #include <petscksp.h>
4: #include <petscmath.h>
6: int main(int argc, char **argv)
7: {
8: PetscInt n = 1000;
9: PetscInt n_epochs = 10;
10: PetscInt n_iters = 10;
11: Vec x, g, dx, df, p;
12: PetscRandom rand;
13: PetscLogStage matsolve_loop, main_stage;
14: Mat B, J0;
16: PetscCall(PetscInitialize(&argc, &argv, NULL, help));
17: PetscCall(KSPInitializePackage());
18: PetscOptionsBegin(PETSC_COMM_WORLD, NULL, help, "KSP");
19: PetscCall(PetscOptionsInt("-n", "Vector size", __FILE__, n, &n, NULL));
20: PetscCall(PetscOptionsInt("-epochs", "Number of epochs", __FILE__, n_epochs, &n_epochs, NULL));
21: PetscCall(PetscOptionsInt("-iters", "Number of iterations per epoch", __FILE__, n_iters, &n_iters, NULL));
22: PetscOptionsEnd();
23: PetscCall(VecCreateMPI(PETSC_COMM_WORLD, PETSC_DETERMINE, n, &x));
24: PetscCall(VecSetFromOptions(x));
25: PetscCall(VecDuplicate(x, &g));
26: PetscCall(VecDuplicate(x, &dx));
27: PetscCall(VecDuplicate(x, &df));
28: PetscCall(VecDuplicate(x, &p));
29: PetscCall(MatCreate(PETSC_COMM_WORLD, &B));
30: PetscCall(MatSetType(B, MATLMVMBFGS));
31: PetscCall(MatLMVMAllocate(B, x, g));
32: PetscCall(MatSetFromOptions(B));
33: PetscCall(MatLMVMGetJ0(B, &J0));
34: PetscCall(MatZeroEntries(J0));
35: PetscCall(MatShift(J0, 1.0));
36: PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &rand));
37: PetscCall(PetscRandomSetInterval(rand, -1.0, 1.0));
38: PetscCall(PetscRandomSetFromOptions(rand));
39: PetscCall(PetscLogStageRegister("LMVM MatSolve Loop", &matsolve_loop));
40: PetscCall(PetscLogStageGetId("Main Stage", &main_stage));
41: PetscCall(PetscLogStageSetVisible(main_stage, PETSC_FALSE));
42: for (PetscInt epoch = 0; epoch < n_epochs + 1; epoch++) {
43: PetscScalar dot;
44: PetscReal xscale, fscale, absdot;
45: PetscInt history_size;
47: PetscCall(VecSetRandom(dx, rand));
48: PetscCall(VecSetRandom(df, rand));
49: PetscCall(VecDot(dx, df, &dot));
50: absdot = PetscAbsScalar(dot);
51: PetscCall(VecSetRandom(x, rand));
52: PetscCall(VecSetRandom(g, rand));
53: xscale = 1.0;
54: fscale = absdot / PetscRealPart(dot);
55: PetscCall(MatLMVMGetHistorySize(B, &history_size));
57: PetscCall(MatLMVMUpdate(B, x, g));
58: for (PetscInt iter = 0; iter < history_size; iter++, xscale *= -1.0, fscale *= -1.0) {
59: PetscCall(VecAXPY(x, xscale, dx));
60: PetscCall(VecAXPY(g, fscale, df));
61: PetscCall(MatLMVMUpdate(B, x, g));
62: PetscCall(MatSolve(B, g, p));
63: }
64: if (epoch > 0) PetscCall(PetscLogStagePush(matsolve_loop));
65: for (PetscInt iter = 0; iter < n_iters; iter++, xscale *= -1.0, fscale *= -1.0) {
66: PetscCall(VecAXPY(x, xscale, dx));
67: PetscCall(VecAXPY(g, fscale, df));
68: PetscCall(MatLMVMUpdate(B, x, g));
69: PetscCall(MatSolve(B, g, p));
70: }
71: PetscCall(MatLMVMReset(B, PETSC_FALSE));
72: if (epoch > 0) PetscCall(PetscLogStagePop());
73: }
74: PetscCall(PetscViewerPushFormat(PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD), PETSC_VIEWER_ASCII_INFO_DETAIL));
75: PetscCall(MatView(B, PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD)));
76: PetscCall(PetscViewerPopFormat(PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD)));
77: PetscCall(PetscRandomDestroy(&rand));
78: PetscCall(MatDestroy(&B));
79: PetscCall(VecDestroy(&p));
80: PetscCall(VecDestroy(&df));
81: PetscCall(VecDestroy(&dx));
82: PetscCall(VecDestroy(&g));
83: PetscCall(VecDestroy(&x));
84: PetscCall(PetscFinalize());
85: return 0;
86: }
88: /*TEST
90: test:
91: suffix: 0
92: args: -mat_lmvm_scale_type none
94: TEST*/