Actual source code: cd_cupm.cxx

  1: #include "../denseqn.h"
  2: #include <petsc/private/cupminterface.hpp>
  3: #include <petsc/private/cupmobject.hpp>

  5: namespace Petsc
  6: {

  8: namespace device
  9: {

 11: namespace cupm
 12: {

 14: namespace impl
 15: {

 17: template <DeviceType T>
 18: struct UpperTriangular : CUPMObject<T> {
 19:   PETSC_CUPMOBJECT_HEADER(T);

 21:   static PetscErrorCode SolveInPlace(PetscDeviceContext, PetscBool, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
 22:   static PetscErrorCode SolveInPlaceCyclic(PetscDeviceContext, PetscBool, PetscInt, PetscInt, PetscInt, const PetscScalar[], PetscInt, PetscScalar[], PetscInt) noexcept;
 23: };

 25: template <DeviceType T>
 26: PetscErrorCode UpperTriangular<T>::SolveInPlace(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt N, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
 27: {
 28:   cupmBlasInt_t    n;
 29:   cupmBlasHandle_t handle;
 30:   auto             A_ = cupmScalarPtrCast(A);
 31:   auto             x_ = cupmScalarPtrCast(x);

 33:   PetscFunctionBegin;
 34:   if (!N) PetscFunctionReturn(PETSC_SUCCESS);
 35:   PetscCall(PetscCUPMBlasIntCast(N, &n));
 36:   PetscCall(GetHandlesFrom_(dctx, &handle));
 37:   PetscCall(PetscLogGpuTimeBegin());
 38:   PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, hermitian_transpose ? CUPMBLAS_OP_C : CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n, A_, lda, x_, stride));
 39:   PetscCall(PetscLogGpuTimeEnd());

 41:   PetscCall(PetscLogGpuFlops(1.0 * N * N));
 42:   PetscFunctionReturn(PETSC_SUCCESS);
 43: }

 45: template <DeviceType T>
 46: PetscErrorCode UpperTriangular<T>::SolveInPlaceCyclic(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride) noexcept
 47: {
 48:   PetscInt              N            = next - oldest;
 49:   PetscInt              oldest_index = oldest % m;
 50:   PetscInt              next_index   = next % m;
 51:   cupmBlasInt_t         n_old, n_new;
 52:   cupmBlasPointerMode_t pointer_mode;
 53:   cupmBlasHandle_t      handle;
 54:   auto                  sone      = cupmScalarCast(1.0);
 55:   auto                  minus_one = cupmScalarCast(-1.0);
 56:   auto                  A_        = cupmScalarPtrCast(A);
 57:   auto                  x_        = cupmScalarPtrCast(x);

 59:   PetscFunctionBegin;
 60:   if (!N) PetscFunctionReturn(PETSC_SUCCESS);
 61:   PetscCall(PetscCUPMBlasIntCast(m - oldest_index, &n_old));
 62:   PetscCall(PetscCUPMBlasIntCast(next_index, &n_new));
 63:   PetscCall(GetHandlesFrom_(dctx, &handle));
 64:   PetscCall(PetscLogGpuTimeBegin());
 65:   PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
 66:   PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
 67:   if (!hermitian_transpose) {
 68:     if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_new, A_, lda, x_, stride));
 69:     if (n_new > 0 && n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_N, n_old, n_new, &minus_one, &A_[oldest_index], lda, x_, stride, &sone, &x_[oldest_index], stride));
 70:     if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_old, &A_[oldest_index * (lda + 1)], lda, &x_[oldest_index], stride));
 71:   } else {
 72:     if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_old, &A_[oldest_index * (lda + 1)], lda, &x_[oldest_index], stride));
 73:     if (n_new > 0 && n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_C, n_old, n_new, &minus_one, &A_[oldest_index], lda, &x_[oldest_index], stride, &sone, x_, stride));
 74:     if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_new, A_, lda, x_, stride));
 75:   }
 76:   PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
 77:   PetscCall(PetscLogGpuTimeEnd());

 79:   PetscCall(PetscLogGpuFlops(1.0 * N * N));
 80:   PetscFunctionReturn(PETSC_SUCCESS);
 81: }

// Explicit instantiation definitions: emit UpperTriangular<T> and its member functions in this
// translation unit for each backend whose HAVE_* macro is defined, so the dispatch functions at the
// end of the file have concrete specializations to call.
#if PetscDefined(HAVE_CUDA)
template struct UpperTriangular<DeviceType::CUDA>;
#endif

#if PetscDefined(HAVE_HIP)
template struct UpperTriangular<DeviceType::HIP>;
#endif

 91: } // namespace impl

 93: } // namespace cupm

 95: } // namespace device

 97: } // namespace Petsc

 99: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_CUPM(PetscBool hermitian_transpose, PetscInt n, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
100: {
101:   using ::Petsc::device::cupm::impl::UpperTriangular;
102:   using ::Petsc::device::cupm::DeviceType;
103:   PetscDeviceContext dctx;
104:   PetscDeviceType    device_type;

106:   PetscFunctionBegin;
107:   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
108:   PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
109:   switch (device_type) {
110: #if PetscDefined(HAVE_CUDA)
111:   case PETSC_DEVICE_CUDA:
112:     PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
113:     break;
114: #endif
115: #if PetscDefined(HAVE_HIP)
116:   case PETSC_DEVICE_HIP:
117:     PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlace(dctx, hermitian_transpose, n, A, lda, x, stride));
118:     break;
119: #endif
120:   default:
121:     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
122:   }
123:   PetscFunctionReturn(PETSC_SUCCESS);
124: }

126: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlaceCyclic_CUPM(PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
127: {
128:   using ::Petsc::device::cupm::impl::UpperTriangular;
129:   using ::Petsc::device::cupm::DeviceType;
130:   PetscDeviceContext dctx;
131:   PetscDeviceType    device_type;

133:   PetscFunctionBegin;
134:   PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
135:   PetscCall(PetscDeviceContextGetDeviceType(dctx, &device_type));
136:   switch (device_type) {
137: #if PetscDefined(HAVE_CUDA)
138:   case PETSC_DEVICE_CUDA:
139:     PetscCall(UpperTriangular<DeviceType::CUDA>::SolveInPlaceCyclic(dctx, hermitian_transpose, m, oldest, next, A, lda, x, stride));
140:     break;
141: #endif
142: #if PetscDefined(HAVE_HIP)
143:   case PETSC_DEVICE_HIP:
144:     PetscCall(UpperTriangular<DeviceType::HIP>::SolveInPlaceCyclic(dctx, hermitian_transpose, m, oldest, next, A, lda, x, stride));
145:     break;
146: #endif
147:   default:
148:     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported device type %s", PetscDeviceTypes[device_type]);
149:   }
150:   PetscFunctionReturn(PETSC_SUCCESS);
151: }