Actual source code: cd_cupm.cxx
1: #include "../denseqn.h"
2: #include <petsc/private/cupminterface.hpp>
3: #include <petsc/private/cupmobject.hpp>
5: namespace Petsc
6: {
8: namespace device
9: {
11: namespace cupm
12: {
14: namespace impl
15: {
17: template <DeviceType T>
18: struct UpperTriangular : CUPMObject<T> {
19: PETSC_CUPMOBJECT_HEADER(T);
21: static PetscErrorCode SolveInPlace(PetscDeviceContext, PetscBool, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
22: static PetscErrorCode SolveInPlaceCyclic(PetscDeviceContext, PetscBool, PetscInt, PetscInt, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
23: };
25: template <DeviceType T>
26: PetscErrorCode UpperTriangular<T>::SolveInPlace(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt N, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
27: {
28: cupmBlasInt_t n;
29: cupmBlasHandle_t handle;
30: auto A_ = cupmScalarPtrCast(A);
31: auto x_ = cupmScalarPtrCast(x);
33: PetscFunctionBegin;
34: if (!N) PetscFunctionReturn(PETSC_SUCCESS);
35: PetscCall(PetscCUPMBlasIntCast(N, &n));
36: PetscCall(GetHandlesFrom_(dctx, &handle));
37: PetscCall(PetscLogGpuTimeBegin());
38: PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, hermitian_transpose ? CUPMBLAS_OP_C : CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n, A_, lda, x_, stride));
39: PetscCall(PetscLogGpuTimeEnd());
41: PetscCall(PetscLogGpuFlops(1.0 * N * N));
42: PetscFunctionReturn(PETSC_SUCCESS);
43: }
45: template <DeviceType T>
46: PetscErrorCode UpperTriangular<T>::SolveInPlaceCyclic(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
47: {
48: PetscInt N = next - oldest;
49: PetscInt oldest_index = oldest % m;
50: PetscInt next_index = next % m;
51: cupmBlasInt_t n_old, n_new;
52: cupmBlasPointerMode_t pointer_mode;
53: cupmBlasHandle_t handle;
54: auto sone = cupmScalarCast(1.0);
55: auto minus_one = cupmScalarCast(-1.0);
56: auto A_ = cupmScalarPtrCast(A);
57: auto x_ = cupmScalarPtrCast(x);
59: PetscFunctionBegin;
60: if (!N) PetscFunctionReturn(PETSC_SUCCESS);
61: PetscCall(PetscCUPMBlasIntCast(m - oldest_index, &n_old));
62: PetscCall(PetscCUPMBlasIntCast(next_index, &n_new));
63: PetscCall(GetHandlesFrom_(dctx, &handle));
64: PetscCall(PetscLogGpuTimeBegin());
65: PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
66: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
67: if (!hermitian_transpose) {
68: if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_new, A_, lda, x_, stride));
69: if (n_new > 0 && n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_N, n_old, n_new, &minus_one, &A_[oldest_index], lda, x_, stride, &sone, &x_[oldest_index], stride));
70: if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_old, &A_[oldest_index * (lda + 1)], lda, &x_[oldest_index], stride));
71: } else {
72: if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_old, &A_[oldest_index * (lda + 1)], lda, &x_[oldest_index], stride));
73: if (n_new > 0 && n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_C, n_old, n_new, &minus_one, &A_[oldest_index], lda, &x_[oldest_index], stride, &sone, x_, stride));
74: if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_new, A_, lda, x_, stride));
75: }
76: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
77: PetscCall(PetscLogGpuTimeEnd());
79: PetscCall(PetscLogGpuFlops(1.0 * N * N));
80: PetscFunctionReturn(PETSC_SUCCESS);
81: }
83: #if PetscDefined(HAVE_CUDA)
84: template struct UpperTriangular<DeviceType::CUDA>;
85: #endif
87: #if PetscDefined(HAVE_HIP)
88: template struct UpperTriangular<DeviceType::HIP>;
89: #endif
91: } // namespace impl
93: } // namespace cupm
95: } // namespace device
97: } // namespace Petsc
99: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_CUPM(PetscBool hermitian_transpose, PetscInt n, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
100: {
101: using ::Petsc::device::cupm::impl::UpperTriangular;
102: using ::Petsc::device::cupm::DeviceType;
103: PetscDeviceContext dctx;
104: PetscDeviceType device_type;
106: PetscFunctionBegin;
107: PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
108: PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
109: switch (device_type) {
110: #if PetscDefined(HAVE_CUDA)
111: case PETSC_DEVICE_CUDA:
112: PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
113: break;
114: #endif
115: #if PetscDefined(HAVE_HIP)
116: case PETSC_DEVICE_HIP:
117: PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
118: break;
119: #endif
120: default:
121: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
122: }
123: PetscFunctionReturn(PETSC_SUCCESS);
124: }
126: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlaceCyclic_CUPM(PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
127: {
128: using ::Petsc::device::cupm::impl::UpperTriangular;
129: using ::Petsc::device::cupm::DeviceType;
130: PetscDeviceContext dctx;
131: PetscDeviceType device_type;
133: PetscFunctionBegin;
134: PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
135: PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
136: switch (device_type) {
137: #if PetscDefined(HAVE_CUDA)
138: case PETSC_DEVICE_CUDA:
139: PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlaceCyclic(dctx, hermitian_transpose, m, oldest, next, A, lda, x, stride));
140: break;
141: #endif
142: #if PetscDefined(HAVE_HIP)
143: case PETSC_DEVICE_HIP:
144: PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlaceCyclic(dctx, hermitian_transpose, m, oldest, next, A, lda, x, stride));
145: break;
146: #endif
147: default:
148: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
149: }
150: PetscFunctionReturn(PETSC_SUCCESS);
151: }