Actual source code: matdiagonalkokkos.kokkos.cxx

  1: #include <petscvec_kokkos.hpp>
  2: #include <petsc_kokkos.hpp>
  3: #include <petsc/private/kokkosimpl.hpp>
  4: #include <petsc/private/vecimpl.h>
  5: #include <petsc/private/matimpl.h>

  7: PETSC_INTERN PetscErrorCode MatADot_Diagonal_SeqKokkos(Mat A, Vec x, Vec y, PetscScalar *z)
  8: {
  9:   Mat_Diagonal              *ctx = (Mat_Diagonal *)A->data;
 10:   ConstPetscScalarKokkosView xv, yv, wv;

 12:   PetscFunctionBegin;
 13:   PetscCall(PetscLogGpuTimeBegin());
 14:   PetscCall(VecGetKokkosView(x, &xv));
 15:   PetscCall(VecGetKokkosView(y, &yv));
 16:   PetscCall(VecGetKokkosView(ctx->diag, &wv));
 17:   // Kokkos always overwrites z, so no need to init it
 18:   PetscCallCXX(Kokkos::parallel_reduce("MatADot_Diagonal", Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, x->map->n), KOKKOS_LAMBDA(const PetscInt &i, PetscScalar &update) { update += PetscConj(yv(i)) * wv(i) * xv(i); }, *z));
 19:   PetscCall(VecRestoreKokkosView(x, &xv));
 20:   PetscCall(VecRestoreKokkosView(y, &yv));
 21:   PetscCall(VecRestoreKokkosView(ctx->diag, &wv));
 22:   PetscCall(PetscLogGpuTimeEnd());
 23:   if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
 24:   PetscFunctionReturn(PETSC_SUCCESS);
 25: }

 27: PETSC_INTERN PetscErrorCode MatANormSq_Diagonal_SeqKokkos(Mat A, Vec x, PetscReal *z)
 28: {
 29:   Mat_Diagonal              *ctx = (Mat_Diagonal *)A->data;
 30:   ConstPetscScalarKokkosView xv, wv;
 31:   PetscScalar                res = 0.;

 33:   PetscFunctionBegin;
 34:   PetscCall(PetscLogGpuTimeBegin());
 35:   PetscCall(VecGetKokkosView(x, &xv));
 36:   PetscCall(VecGetKokkosView(ctx->diag, &wv));
 37:   PetscCallCXX(Kokkos::parallel_reduce("MatANorm_Diagonal", Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, x->map->n), KOKKOS_LAMBDA(const PetscInt &i, PetscScalar &update) { update += PetscConj(xv(i)) * wv(i) * xv(i); }, res));
 38:   PetscCall(VecRestoreKokkosView(x, &xv));
 39:   PetscCall(VecRestoreKokkosView(ctx->diag, &wv));
 40:   PetscCall(PetscLogGpuTimeEnd());
 41:   *z = PetscRealPart(res);
 42:   if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
 43:   PetscFunctionReturn(PETSC_SUCCESS);
 44: }