Actual source code: matdiagonalcupm.hpp
1: #pragma once
3: #include <petscmat.h>
5: #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
7: #include <petsc/private/cupminterface.hpp>
8: #include <petsc/private/cupmobject.hpp>
9: #include <petsc/private/deviceimpl.h>
10: #include <petsc/private/vecimpl.h>
11: #include <petsc/private/veccupmimpl.h>
12: #include <petsc/private/matimpl.h>
14: #include <thrust/device_ptr.h>
15: #include <thrust/iterator/zip_iterator.h>
16: #include <thrust/transform_reduce.h>
18: namespace Petsc
19: {
21: namespace device
22: {
24: namespace cupm
25: {
27: namespace impl
28: {
30: template <DeviceType T, typename VecType>
31: struct MatDiagonal_CUPM : vec::cupm::impl::Vec_CUPMBase<T, VecType> {
32: PETSC_CUPMOBJECT_HEADER(T);
33: using base_type = ::Petsc::vec::cupm::impl::Vec_CUPMBase<T, VecType>;
34: friend base_type;
36: static PetscErrorCode ADot(Mat A, Vec x, Vec y, PetscScalar *z) noexcept;
37: static PetscErrorCode ANormSq(Mat A, Vec x, PetscReal *z) noexcept;
38: };
40: namespace detail
41: {
42: struct adot_transform {
43: using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar>;
45: PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const argument_type &tup) const noexcept { return PetscConj(thrust::get<1>(tup)) * thrust::get<2>(tup) * thrust::get<0>(tup); }
46: };
47: } // namespace detail
49: template <Petsc::device::cupm::DeviceType T, typename VecType>
50: inline PetscErrorCode MatDiagonal_CUPM<T, VecType>::ADot(Mat A, Vec x, Vec y, PetscScalar *z) noexcept
51: {
52: PetscDeviceContext dctx;
53: cupmStream_t stream;
54: Mat_Diagonal *ctx = (Mat_Diagonal *)A->data;
55: PetscScalar zero = 0.;
56: const PetscInt n = x->map->n;
58: PetscFunctionBegin;
59: PetscCall(GetHandles_(&dctx, &stream));
61: const auto xdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, x).data());
62: const auto ydptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, y).data());
63: const auto wdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, ctx->diag).data());
65: // clang-format off
66: PetscCallThrust(
67: *z = THRUST_CALL(
68: thrust::transform_reduce,
69: stream,
70: thrust::make_zip_iterator(thrust::make_tuple(xdptr, ydptr, wdptr)),
71: thrust::make_zip_iterator(thrust::make_tuple(xdptr + n, ydptr + n, wdptr + n)),
72: detail::adot_transform{},
73: zero,
74: thrust::plus<PetscScalar>()
75: )
76: );
77: // clang-format on
78: if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
79: PetscFunctionReturn(PETSC_SUCCESS);
80: }
82: namespace detail
83: {
84: struct anorm_transform {
85: using argument_type = thrust::tuple<PetscScalar, PetscScalar>;
87: PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const argument_type &tup) const noexcept { return thrust::get<1>(tup) * PetscConj(thrust::get<0>(tup)) * thrust::get<0>(tup); }
88: };
89: } // namespace detail
91: template <Petsc::device::cupm::DeviceType T, typename VecType>
92: inline PetscErrorCode MatDiagonal_CUPM<T, VecType>::ANormSq(Mat A, Vec x, PetscReal *z) noexcept
93: {
94: PetscDeviceContext dctx;
95: cupmStream_t stream;
96: Mat_Diagonal *ctx = (Mat_Diagonal *)A->data;
97: PetscScalar zero = 0., res;
98: const PetscInt n = x->map->n;
100: PetscFunctionBegin;
101: PetscCall(GetHandles_(&dctx, &stream));
103: const auto xdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, x).data());
104: const auto wdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, ctx->diag).data());
106: // clang-format off
107: PetscCallThrust(
108: res = THRUST_CALL(
109: thrust::transform_reduce,
110: stream,
111: thrust::make_zip_iterator(thrust::make_tuple(xdptr, wdptr)),
112: thrust::make_zip_iterator(thrust::make_tuple(xdptr + n, wdptr + n)),
113: detail::anorm_transform{},
114: zero,
115: thrust::plus<PetscScalar>()
116: )
117: );
118: // clang-format on
119: *z = PetscRealPart(res);
120: if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
121: PetscFunctionReturn(PETSC_SUCCESS);
122: }
124: } // namespace impl
126: } // namespace cupm
128: } // namespace device
130: } // namespace Petsc