Actual source code: blas_cyclic_cupm_impl.hpp
1: #pragma once
2: #include "blas_cyclic_cupm.h"
3: #include <petsc/private/cupminterface.hpp>
4: #include <petsc/private/cupmobject.hpp>
6: namespace Petsc
7: {
9: namespace device
10: {
12: namespace cupm
13: {
15: namespace impl
16: {
18: template <DeviceType T>
19: struct BLASCyclic : CUPMObject<T> {
20: PETSC_CUPMOBJECT_HEADER(T);
22: static PetscErrorCode axpby_dispatch(cupmBlasHandle_t, cupmBlasInt_t, PetscScalar, const PetscScalar[], PetscScalar, PetscScalar[], cupmBlasInt_t) noexcept;
23: static PetscErrorCode axpby(PetscDeviceContext, PetscInt, PetscInt, PetscInt, PetscScalar, const PetscScalar[], PetscScalar, PetscScalar[], PetscInt) noexcept;
24: static PetscErrorCode dmv(PetscDeviceContext, PetscBool, PetscInt, PetscInt, PetscInt, PetscScalar, const PetscScalar[], const PetscScalar[], PetscScalar, PetscScalar[]) noexcept;
25: static PetscErrorCode dsv(PetscDeviceContext, PetscBool, PetscInt, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], PetscScalar[]) noexcept;
26: static PetscErrorCode trsv(PetscDeviceContext, PetscBool, PetscInt, PetscInt, PetscInt, const PetscScalar[], PetscInt, const PetscScalar[], PetscScalar[]) noexcept;
27: static PetscErrorCode gemv(PetscDeviceContext, PetscBool, PetscInt, PetscInt, PetscInt, PetscScalar, const PetscScalar[], PetscInt, const PetscScalar[], PetscScalar, PetscScalar[]) noexcept;
28: static PetscErrorCode hemv(PetscDeviceContext, PetscInt, PetscInt, PetscInt, PetscScalar, const PetscScalar[], PetscInt, const PetscScalar[], PetscScalar, PetscScalar[]) noexcept;
29: };
31: template <DeviceType T>
32: PetscErrorCode BLASCyclic<T>::axpby_dispatch(cupmBlasHandle_t handle, cupmBlasInt_t n, PetscScalar alpha, const PetscScalar x[], PetscScalar beta, PetscScalar y[], cupmBlasInt_t y_stride) noexcept
33: {
34: auto x_ = cupmScalarPtrCast(x);
35: auto y_ = cupmScalarPtrCast(y);
36: const auto calpha = cupmScalarPtrCast(&alpha);
37: const auto cbeta = cupmScalarPtrCast(&beta);
39: PetscFunctionBegin;
40: if (alpha == 1.0 && beta == 0.0) {
41: PetscCallCUPMBLAS(cupmBlasXcopy(handle, n, x_, 1, y_, y_stride));
42: } else {
43: if (beta != 1.0) PetscCallCUPMBLAS(cupmBlasXscal(handle, n, cbeta, y_, y_stride));
44: if (alpha != 0.0) PetscCallCUPMBLAS(cupmBlasXaxpy(handle, n, calpha, x_, 1, y_, y_stride));
45: }
46: PetscFunctionReturn(PETSC_SUCCESS);
47: }
49: template <DeviceType T>
50: PetscErrorCode BLASCyclic<T>::axpby(PetscDeviceContext dctx, PetscInt M, PetscInt oldest, PetscInt next, PetscScalar alpha, const PetscScalar x[], PetscScalar beta, PetscScalar y[], PetscInt y_stride) noexcept
51: {
52: PetscInt N = next - oldest;
53: cupmBlasInt_t m, i_oldest, i_next, y_stride_;
54: cupmBlasPointerMode_t pointer_mode;
55: cupmBlasHandle_t handle;
57: PetscFunctionBegin;
58: if (!N) PetscFunctionReturn(PETSC_SUCCESS);
59: PetscCall(PetscCUPMBlasIntCast(M, &m));
60: PetscCall(PetscCUPMBlasIntCast(oldest % m, &i_oldest));
61: PetscCall(PetscCUPMBlasIntCast(((next - 1) % m) + 1, &i_next));
62: PetscCall(PetscCUPMBlasIntCast(y_stride, &y_stride_));
63: PetscCall(GetHandlesFrom_(dctx, &handle));
64: PetscCall(PetscLogGpuTimeBegin());
65: PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
66: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
67: if (N == m) {
68: PetscCall(axpby_dispatch(handle, m, alpha, x, beta, y, y_stride_));
69: } else if (i_next > i_oldest) {
70: cupmBlasInt_t diff = i_next - i_oldest;
72: PetscCall(axpby_dispatch(handle, diff, alpha, &x[i_oldest], beta, &y[i_oldest * y_stride], y_stride_));
73: } else {
74: cupmBlasInt_t diff = m - i_oldest;
76: if (i_next) PetscCall(axpby_dispatch(handle, i_next, alpha, x, beta, y, y_stride_));
77: if (diff) PetscCall(axpby_dispatch(handle, diff, alpha, &x[i_oldest], beta, &y[i_oldest * y_stride], y_stride_));
78: }
79: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
80: PetscCall(PetscLogGpuTimeEnd());
82: PetscCall(PetscLogGpuFlops(3.0 * N));
83: PetscFunctionReturn(PETSC_SUCCESS);
84: }
86: template <DeviceType T>
87: PetscErrorCode BLASCyclic<T>::dmv(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt M, PetscInt oldest, PetscInt next, PetscScalar alpha, const PetscScalar A[], const PetscScalar x[], PetscScalar beta, PetscScalar y[]) noexcept
88: {
89: PetscInt N = next - oldest;
90: cupmBlasInt_t m, i_oldest, i_next;
91: cupmBlasPointerMode_t pointer_mode;
92: cupmBlasHandle_t handle;
93: const auto A_ = cupmScalarPtrCast(A);
94: const auto x_ = cupmScalarPtrCast(x);
95: const auto y_ = cupmScalarPtrCast(y);
96: const auto calpha = cupmScalarPtrCast(&alpha);
97: const auto cbeta = cupmScalarPtrCast(&beta);
98: const auto trans = hermitian_transpose ? CUPMBLAS_OP_C : CUPMBLAS_OP_N;
100: PetscFunctionBegin;
101: if (!N) PetscFunctionReturn(PETSC_SUCCESS);
102: PetscCall(PetscCUPMBlasIntCast(M, &m));
103: PetscCall(PetscCUPMBlasIntCast(oldest % m, &i_oldest));
104: PetscCall(PetscCUPMBlasIntCast(((next - 1) % m) + 1, &i_next));
105: PetscCall(GetHandlesFrom_(dctx, &handle));
106: PetscCall(PetscLogGpuTimeBegin());
107: PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
108: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
109: if (N == m) {
110: PetscCallCUPMBLAS(cupmBlasXgbmv(handle, trans, m, m, 0, 0, calpha, A_, 1, x_, 1, cbeta, y_, 1));
111: } else if (i_next > i_oldest) {
112: cupmBlasInt_t diff = i_next - i_oldest;
114: PetscCallCUPMBLAS(cupmBlasXgbmv(handle, trans, diff, diff, 0, 0, calpha, &A_[i_oldest], 1, &x_[i_oldest], 1, cbeta, &y_[i_oldest], 1));
115: } else {
116: cupmBlasInt_t diff = m - i_oldest;
118: if (i_next) PetscCallCUPMBLAS(cupmBlasXgbmv(handle, trans, i_next, i_next, 0, 0, calpha, A_, 1, x_, 1, cbeta, y_, 1));
119: if (diff) PetscCallCUPMBLAS(cupmBlasXgbmv(handle, trans, diff, diff, 0, 0, calpha, &A_[i_oldest], 1, &x_[i_oldest], 1, cbeta, &y_[i_oldest], 1));
120: }
121: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
122: PetscCall(PetscLogGpuTimeEnd());
124: PetscCall(PetscLogGpuFlops(3.0 * N));
125: PetscFunctionReturn(PETSC_SUCCESS);
126: }
128: template <DeviceType T>
129: PetscErrorCode BLASCyclic<T>::dsv(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt M, PetscInt oldest, PetscInt next, const PetscScalar A[], const PetscScalar x[], PetscScalar y[]) noexcept
130: {
131: PetscInt N = next - oldest;
132: cupmBlasInt_t m, i_oldest, i_next;
133: cupmBlasPointerMode_t pointer_mode;
134: cupmBlasHandle_t handle;
135: cupmStream_t stream;
136: const auto A_ = cupmScalarPtrCast(A);
137: const auto y_ = cupmScalarPtrCast(y);
138: auto trans = hermitian_transpose ? CUPMBLAS_OP_C : CUPMBLAS_OP_N;
140: PetscFunctionBegin;
141: if (!N) PetscFunctionReturn(PETSC_SUCCESS);
142: PetscCall(PetscCUPMBlasIntCast(M, &m));
143: PetscCall(PetscCUPMBlasIntCast(oldest % m, &i_oldest));
144: PetscCall(PetscCUPMBlasIntCast(((next - 1) % m) + 1, &i_next));
145: PetscCall(GetHandlesFrom_(dctx, &handle, NULL, &stream));
146: PetscCall(PetscLogGpuTimeBegin());
147: PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
148: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
149: if (N == m) {
150: if (x != y) PetscCall(PetscCUPMMemcpyAsync(y, x, m, cupmMemcpyDeviceToDevice, stream));
151: PetscCallCUPMBLAS(cupmBlasXtbsv(handle, CUPMBLAS_FILL_MODE_UPPER, trans, CUPMBLAS_DIAG_NON_UNIT, m, 0, A_, 1, y_, 1));
152: } else if (i_next > i_oldest) {
153: cupmBlasInt_t diff = i_next - i_oldest;
155: if (x != y) PetscCall(PetscCUPMMemcpyAsync(&y[i_oldest], &x[i_oldest], diff, cupmMemcpyDeviceToDevice, stream));
156: PetscCallCUPMBLAS(cupmBlasXtbsv(handle, CUPMBLAS_FILL_MODE_UPPER, trans, CUPMBLAS_DIAG_NON_UNIT, diff, 0, &A_[i_oldest], 1, &y_[i_oldest], 1));
157: } else {
158: cupmBlasInt_t diff = m - i_oldest;
160: if (i_next) {
161: if (x != y) PetscCall(PetscCUPMMemcpyAsync(y, x, i_next, cupmMemcpyDeviceToDevice, stream));
162: PetscCallCUPMBLAS(cupmBlasXtbsv(handle, CUPMBLAS_FILL_MODE_UPPER, trans, CUPMBLAS_DIAG_NON_UNIT, i_next, 0, A_, 1, y_, 1));
163: }
164: if (diff) {
165: if (x != y) PetscCall(PetscCUPMMemcpyAsync(&y[i_oldest], &x[i_oldest], diff, cupmMemcpyDeviceToDevice, stream));
166: PetscCallCUPMBLAS(cupmBlasXtbsv(handle, CUPMBLAS_FILL_MODE_UPPER, trans, CUPMBLAS_DIAG_NON_UNIT, diff, 0, &A_[i_oldest], 1, &y_[i_oldest], 1));
167: }
168: }
169: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
170: PetscCall(PetscLogGpuTimeEnd());
172: PetscCall(PetscLogGpuFlops(3.0 * N));
173: PetscFunctionReturn(PETSC_SUCCESS);
174: }
176: template <DeviceType T>
177: PetscErrorCode BLASCyclic<T>::trsv(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, const PetscScalar x[], PetscScalar y[]) noexcept
178: {
179: PetscInt N = next - oldest;
180: PetscInt i_oldest = oldest % m;
181: PetscInt i_next = ((next - 1) % m) + 1;
182: cupmBlasInt_t n, n_old, n_new;
183: cupmBlasPointerMode_t pointer_mode;
184: cupmBlasHandle_t handle;
185: cupmStream_t stream;
186: auto sone = cupmScalarCast(1.0);
187: auto minus_one = cupmScalarCast(-1.0);
188: auto A_ = cupmScalarPtrCast(A);
189: auto y_ = cupmScalarPtrCast(y);
191: PetscFunctionBegin;
192: if (!N) PetscFunctionReturn(PETSC_SUCCESS);
193: PetscCall(PetscCUPMBlasIntCast(i_next - i_oldest, &n));
194: PetscCall(PetscCUPMBlasIntCast(m - i_oldest, &n_old));
195: PetscCall(PetscCUPMBlasIntCast(i_next, &n_new));
196: PetscCall(GetHandlesFrom_(dctx, &handle, NULL, &stream));
197: PetscCall(PetscLogGpuTimeBegin());
198: PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
199: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
200: if (n > 0) {
201: if (x != y) PetscCall(PetscCUPMMemcpyAsync(y, x, n, cupmMemcpyDeviceToDevice, stream));
202: PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, hermitian_transpose ? CUPMBLAS_OP_C : CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n, &A_[i_oldest * (lda + 1)], lda, y_, 1));
203: } else if (!hermitian_transpose) {
204: if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_new, A_, lda, y_, 1));
205: if (n_new > 0 && n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_N, n_old, n_new, &minus_one, &A_[i_oldest], lda, y_, 1, &sone, &y_[i_oldest], 1));
206: if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_N, CUPMBLAS_DIAG_NON_UNIT, n_old, &A_[i_oldest * (lda + 1)], lda, &y_[i_oldest], 1));
207: } else {
208: if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_old, &A_[i_oldest * (lda + 1)], lda, &y_[i_oldest], 1));
209: if (n_new > 0 && n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_C, n_old, n_new, &minus_one, &A_[i_oldest], lda, &y_[i_oldest], 1, &sone, y_, 1));
210: if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXtrsv(handle, CUPMBLAS_FILL_MODE_UPPER, CUPMBLAS_OP_C, CUPMBLAS_DIAG_NON_UNIT, n_new, A_, lda, y_, 1));
211: }
212: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
213: PetscCall(PetscLogGpuTimeEnd());
215: PetscCall(PetscLogGpuFlops(1.0 * N * N));
216: PetscFunctionReturn(PETSC_SUCCESS);
217: }
219: template <DeviceType T>
220: PetscErrorCode BLASCyclic<T>::hemv(PetscDeviceContext dctx, PetscInt m, PetscInt oldest, PetscInt next, PetscScalar alpha, const PetscScalar A[], PetscInt lda, const PetscScalar x[], PetscScalar beta, PetscScalar y[]) noexcept
221: {
222: PetscInt N = next - oldest;
223: PetscInt i_oldest = oldest % m;
224: PetscInt i_next = ((next - 1) % m) + 1;
225: cupmBlasInt_t n, n_old, n_new;
226: cupmBlasPointerMode_t pointer_mode;
227: cupmBlasHandle_t handle;
228: cupmStream_t stream;
229: auto sone = cupmScalarCast(1.0);
230: auto A_ = cupmScalarPtrCast(A);
231: auto x_ = cupmScalarPtrCast(x);
232: auto y_ = cupmScalarPtrCast(y);
233: const auto calpha = cupmScalarPtrCast(&alpha);
234: const auto cbeta = cupmScalarPtrCast(&beta);
236: PetscFunctionBegin;
237: if (!N) PetscFunctionReturn(PETSC_SUCCESS);
238: PetscCall(PetscCUPMBlasIntCast(i_next - i_oldest, &n));
239: PetscCall(PetscCUPMBlasIntCast(m - i_oldest, &n_old));
240: PetscCall(PetscCUPMBlasIntCast(i_next, &n_new));
241: PetscCall(GetHandlesFrom_(dctx, &handle, NULL, &stream));
242: PetscCall(PetscLogGpuTimeBegin());
243: PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
244: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
245: if (n > 0) {
246: PetscCallCUPMBLAS(cupmBlasXhemv(handle, CUPMBLAS_FILL_MODE_UPPER, n, calpha, &A_[i_oldest * (lda + 1)], lda, &x_[i_oldest], 1, cbeta, &y_[i_oldest], 1));
247: } else {
248: if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXhemv(handle, CUPMBLAS_FILL_MODE_UPPER, n_new, calpha, A_, lda, x_, 1, cbeta, y_, 1));
249: if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXhemv(handle, CUPMBLAS_FILL_MODE_UPPER, n_old, calpha, &A_[i_oldest * (lda + 1)], lda, &x_[i_oldest], 1, cbeta, &y_[i_oldest], 1));
250: if (n_new > 0 && n_old > 0) {
251: PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_N, n_old, n_new, calpha, &A_[i_oldest], lda, x_, 1, &sone, &y_[i_oldest], 1));
252: PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_C, n_old, n_new, calpha, &A_[i_oldest], lda, &x_[i_oldest], 1, &sone, y_, 1));
253: }
254: }
255: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
256: PetscCall(PetscLogGpuTimeEnd());
258: PetscCall(PetscLogGpuFlops(2.0 * N * N));
259: PetscFunctionReturn(PETSC_SUCCESS);
260: }
262: template <DeviceType T>
263: PetscErrorCode BLASCyclic<T>::gemv(PetscDeviceContext dctx, PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, PetscScalar alpha, const PetscScalar A[], PetscInt lda, const PetscScalar x[], PetscScalar beta, PetscScalar y[]) noexcept
264: {
265: PetscInt N = next - oldest;
266: PetscInt i_oldest = oldest % m;
267: PetscInt i_next = ((next - 1) % m) + 1;
268: cupmBlasInt_t n, n_old, n_new;
269: cupmBlasPointerMode_t pointer_mode;
270: cupmBlasHandle_t handle;
271: cupmStream_t stream;
272: auto sone = cupmScalarCast(1.0);
273: auto A_ = cupmScalarPtrCast(A);
274: auto x_ = cupmScalarPtrCast(x);
275: auto y_ = cupmScalarPtrCast(y);
276: auto trans = hermitian_transpose ? CUPMBLAS_OP_C : CUPMBLAS_OP_N;
277: const auto calpha = cupmScalarPtrCast(&alpha);
278: const auto cbeta = cupmScalarPtrCast(&beta);
280: PetscFunctionBegin;
281: if (!N) PetscFunctionReturn(PETSC_SUCCESS);
282: PetscCall(PetscCUPMBlasIntCast(i_next - i_oldest, &n));
283: PetscCall(PetscCUPMBlasIntCast(m - i_oldest, &n_old));
284: PetscCall(PetscCUPMBlasIntCast(i_next, &n_new));
285: PetscCall(GetHandlesFrom_(dctx, &handle, NULL, &stream));
286: PetscCall(PetscLogGpuTimeBegin());
287: PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &pointer_mode));
288: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_HOST));
289: if (N == m) {
290: PetscCallCUPMBLAS(cupmBlasXgemv(handle, trans, N, N, calpha, A_, lda, x_, 1, cbeta, y_, 1));
291: } else if (n > 0) {
292: PetscCallCUPMBLAS(cupmBlasXgemv(handle, trans, n, n, calpha, &A_[i_oldest * (lda + 1)], lda, &x_[i_oldest], 1, cbeta, &y_[i_oldest], 1));
293: } else {
294: if (n_new > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, trans, n_new, n_new, calpha, A_, lda, x_, 1, cbeta, y_, 1));
295: if (n_old > 0) PetscCallCUPMBLAS(cupmBlasXgemv(handle, trans, n_old, n_old, calpha, &A_[i_oldest * (lda + 1)], lda, &x_[i_oldest], 1, cbeta, &y_[i_oldest], 1));
296: if (n_new > 0 && n_old > 0) {
297: if (!hermitian_transpose) {
298: PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_N, n_old, n_new, calpha, &A_[i_oldest], lda, x_, 1, &sone, &y_[i_oldest], 1));
299: PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_N, n_new, n_old, calpha, &A_[i_oldest * lda], lda, &x_[i_oldest], 1, &sone, y_, 1));
300: } else {
301: PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_C, n_new, n_old, calpha, &A_[i_oldest * lda], lda, x_, 1, &sone, &y_[i_oldest], 1));
302: PetscCallCUPMBLAS(cupmBlasXgemv(handle, CUPMBLAS_OP_C, n_old, n_new, calpha, &A_[i_oldest], lda, &x_[i_oldest], 1, &sone, y_, 1));
303: }
304: }
305: }
306: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, pointer_mode));
307: PetscCall(PetscLogGpuTimeEnd());
309: PetscCall(PetscLogGpuFlops(2.0 * N * N));
310: PetscFunctionReturn(PETSC_SUCCESS);
311: }
313: } // namespace impl
315: } // namespace cupm
317: } // namespace device
319: } // namespace Petsc