Actual source code: solve_performance.c

  1: const char help[] = "Profile the performance of MATLMVM MatSolve() in a loop";

  3: #include <petscksp.h>
  4: #include <petscmath.h>

  6: int main(int argc, char **argv)
  7: {
  8:   PetscInt      n        = 1000;
  9:   PetscInt      n_epochs = 10;
 10:   PetscInt      n_iters  = 10;
 11:   Vec           x, g, dx, df, p;
 12:   PetscRandom   rand;
 13:   PetscLogStage matsolve_loop, main_stage;
 14:   Mat           B, J0;

 16:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
 17:   PetscCall(KSPInitializePackage());
 18:   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, help, "KSP");
 19:   PetscCall(PetscOptionsInt("-n", "Vector size", __FILE__, n, &n, NULL));
 20:   PetscCall(PetscOptionsInt("-epochs", "Number of epochs", __FILE__, n_epochs, &n_epochs, NULL));
 21:   PetscCall(PetscOptionsInt("-iters", "Number of iterations per epoch", __FILE__, n_iters, &n_iters, NULL));
 22:   PetscOptionsEnd();
 23:   PetscCall(VecCreateMPI(PETSC_COMM_WORLD, PETSC_DETERMINE, n, &x));
 24:   PetscCall(VecSetFromOptions(x));
 25:   PetscCall(VecDuplicate(x, &g));
 26:   PetscCall(VecDuplicate(x, &dx));
 27:   PetscCall(VecDuplicate(x, &df));
 28:   PetscCall(VecDuplicate(x, &p));
 29:   PetscCall(MatCreate(PETSC_COMM_WORLD, &B));
 30:   PetscCall(MatSetType(B, MATLMVMBFGS));
 31:   PetscCall(MatLMVMAllocate(B, x, g));
 32:   PetscCall(MatSetFromOptions(B));
 33:   PetscCall(MatLMVMGetJ0(B, &J0));
 34:   PetscCall(MatZeroEntries(J0));
 35:   PetscCall(MatShift(J0, 1.0));
 36:   PetscCall(PetscRandomCreate(PETSC_COMM_WORLD, &rand));
 37:   PetscCall(PetscRandomSetInterval(rand, -1.0, 1.0));
 38:   PetscCall(PetscRandomSetFromOptions(rand));
 39:   PetscCall(PetscLogStageRegister("LMVM MatSolve Loop", &matsolve_loop));
 40:   PetscCall(PetscLogStageGetId("Main Stage", &main_stage));
 41:   PetscCall(PetscLogStageSetVisible(main_stage, PETSC_FALSE));
 42:   for (PetscInt epoch = 0; epoch < n_epochs + 1; epoch++) {
 43:     PetscScalar dot;
 44:     PetscReal   xscale, fscale, absdot;
 45:     PetscInt    history_size;

 47:     PetscCall(VecSetRandom(dx, rand));
 48:     PetscCall(VecSetRandom(df, rand));
 49:     PetscCall(VecDot(dx, df, &dot));
 50:     absdot = PetscAbsScalar(dot);
 51:     PetscCall(VecSetRandom(x, rand));
 52:     PetscCall(VecSetRandom(g, rand));
 53:     xscale = 1.0;
 54:     fscale = absdot / PetscRealPart(dot);
 55:     PetscCall(MatLMVMGetHistorySize(B, &history_size));

 57:     PetscCall(MatLMVMUpdate(B, x, g));
 58:     for (PetscInt iter = 0; iter < history_size; iter++, xscale *= -1.0, fscale *= -1.0) {
 59:       PetscCall(VecAXPY(x, xscale, dx));
 60:       PetscCall(VecAXPY(g, fscale, df));
 61:       PetscCall(MatLMVMUpdate(B, x, g));
 62:       PetscCall(MatSolve(B, g, p));
 63:     }
 64:     if (epoch > 0) PetscCall(PetscLogStagePush(matsolve_loop));
 65:     for (PetscInt iter = 0; iter < n_iters; iter++, xscale *= -1.0, fscale *= -1.0) {
 66:       PetscCall(VecAXPY(x, xscale, dx));
 67:       PetscCall(VecAXPY(g, fscale, df));
 68:       PetscCall(MatLMVMUpdate(B, x, g));
 69:       PetscCall(MatSolve(B, g, p));
 70:     }
 71:     PetscCall(MatLMVMReset(B, PETSC_FALSE));
 72:     if (epoch > 0) PetscCall(PetscLogStagePop());
 73:   }
 74:   PetscCall(PetscViewerPushFormat(PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD), PETSC_VIEWER_ASCII_INFO_DETAIL));
 75:   PetscCall(MatView(B, PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD)));
 76:   PetscCall(PetscViewerPopFormat(PETSC_VIEWER_STDOUT_(PETSC_COMM_WORLD)));
 77:   PetscCall(PetscRandomDestroy(&rand));
 78:   PetscCall(MatDestroy(&B));
 79:   PetscCall(VecDestroy(&p));
 80:   PetscCall(VecDestroy(&df));
 81:   PetscCall(VecDestroy(&dx));
 82:   PetscCall(VecDestroy(&g));
 83:   PetscCall(VecDestroy(&x));
 84:   PetscCall(PetscFinalize());
 85:   return 0;
 86: }

 88: /*TEST

 90:   test:
 91:     suffix: 0
 92:     args: -mat_lmvm_scale_type none

 94: TEST*/