Actual source code: vecseqcupm_impl.hpp

#pragma once

#include "vecseqcupm.hpp"

#include <petsc/private/randomimpl.h>

#include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
#include "../src/sys/objects/device/impls/cupm/kernels.hpp"

#if PetscDefined(USE_COMPLEX)
  #include <thrust/transform_reduce.h>
#endif
#include <thrust/transform.h>
#include <thrust/reduce.h>
#include <thrust/functional.h>
#include <thrust/tuple.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/inner_product.h>

namespace Petsc
{

namespace vec
{

namespace cupm
{

namespace impl
{

// ==========================================================================================
// VecSeq_CUPM - Private API
// ==========================================================================================

template <device::cupm::DeviceType T>
inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_Seq *>(v->data);
}

template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}

template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
{
  return VECSEQ;
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ClearAsyncFunctions(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Abs), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBY), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBYPCZ), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPY), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AYPX), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Conjugate), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Copy), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Exp), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Log), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(MAXPY), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseDivide), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMax), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMaxAbs), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMin), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMult), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Reciprocal), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Scale), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Set), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Shift), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(SqrtAbs), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Swap), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(WAXPY), nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::InitializeAsyncFunctions(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Abs), VecSeq_CUPM<T>::AbsAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBY), VecSeq_CUPM<T>::AXPBYAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBYPCZ), VecSeq_CUPM<T>::AXPBYPCZAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPY), VecSeq_CUPM<T>::AXPYAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AYPX), VecSeq_CUPM<T>::AYPXAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Conjugate), VecSeq_CUPM<T>::ConjugateAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Copy), VecSeq_CUPM<T>::CopyAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Exp), VecSeq_CUPM<T>::ExpAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Log), VecSeq_CUPM<T>::LogAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(MAXPY), VecSeq_CUPM<T>::MAXPYAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseDivide), VecSeq_CUPM<T>::PointwiseDivideAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMax), VecSeq_CUPM<T>::PointwiseMaxAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMaxAbs), VecSeq_CUPM<T>::PointwiseMaxAbsAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMin), VecSeq_CUPM<T>::PointwiseMinAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMult), VecSeq_CUPM<T>::PointwiseMultAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Reciprocal), VecSeq_CUPM<T>::ReciprocalAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Scale), VecSeq_CUPM<T>::ScaleAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Set), VecSeq_CUPM<T>::SetAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Shift), VecSeq_CUPM<T>::ShiftAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(SqrtAbs), VecSeq_CUPM<T>::SqrtAbsAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Swap), VecSeq_CUPM<T>::SwapAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(WAXPY), VecSeq_CUPM<T>::WAXPYAsync));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(ClearAsyncFunctions(v));
  PetscCall(VecDestroy_Seq(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscCall(InitializeAsyncFunctions(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}
// for functions with an early return based on vec size we still need to artificially bump
// the object state. This is to prevent the following:
//
// 0. Suppose you have a Vec {
//   rank 0: [0],
//   rank 1: []
// }
// 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
// 2. Vec enters e.g. VecSet(10)
// 3. rank 1 has local size 0 and bails immediately
// 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
// 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
// 6. Vec enters VecNorm(), and calls VecNormAvailable()
// 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
// 8. rank 0 has object state = 1, not equal to stash, continues to impl function
// 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
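//
// A minimal illustration of the failure mode (hypothetical caller code, not part of this
// file):
//
//   VecSet(v, 10.0);             // rank 1 returns before bumping its object state
//   VecNorm(v, NORM_2, &norm);   // rank 1 serves the stashed norm and skips the collective
//                                // MPI_Allreduce() that rank 0 enters -> deadlock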
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    cupmStream_t stream;

    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    PetscCall(GetHandlesFrom_(dctx, &stream));
    // clang-format off
    PetscCallThrust(
      const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

      THRUST_CALL(
        thrust::transform,
        stream,
        dxptr, dxptr + n,
        thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
        thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
        std::forward<BinaryFuncT>(binary)
      )
    );
    // clang-format on
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
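
// Usage sketch (illustrative): the pointwise binary operations further down all funnel
// through this helper, e.g. PointwiseMultAsync() passes thrust::multiplies<PetscScalar>{}
// as the `binary` argument.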

template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinaryDispatch_(PetscErrorCode (*VecSeqFunction)(Vec, Vec, Vec), BinaryFuncT &&binary, Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (xin->boundtocpu || yin->boundtocpu) {
    PetscCall((*VecSeqFunction)(wout, xin, yin));
  } else {
    // note order of arguments! xin and yin are read, wout is written!
    PetscCall(PointwiseBinary_(std::forward<BinaryFuncT>(binary), xin, yin, wout, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yin, PetscDeviceContext dctx) noexcept
{
  const auto inplace = !yin || (xinout == yin);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    cupmStream_t stream;
    const auto   apply = [&](PetscScalar *xinout, PetscScalar *yin = nullptr) {
      PetscFunctionBegin;
      // clang-format off
      PetscCallThrust(
        const auto xptr = thrust::device_pointer_cast(xinout);

        THRUST_CALL(
          thrust::transform,
          stream,
          xptr, xptr + n,
          (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr,
          std::forward<UnaryFuncT>(unary)
        )
      );
      // clang-format on
      PetscFunctionReturn(PETSC_SUCCESS);
    };

    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    PetscCall(GetHandlesFrom_(dctx, &stream));
    if (inplace) {
      PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
    }
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yin));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Constructors
// ==========================================================================================

// VecCreateSeqCUPM()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecCreateSeqCUPMWithArrays()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateSeqCUPM_() is called!
  PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
  PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->duplicate
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Utility
// ==========================================================================================

// v->ops->bindtocpu
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Mutators
// ==========================================================================================

// v->ops->getlocalvector or v->ops->getlocalvectorread
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->restorelocalvector or v->ops->restorelocalvectorread
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================
// VecSeq_CUPM - Public API - Compute Methods
// ==========================================================================================

// VecAYPXAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPXAsync(Vec yin, PetscScalar alpha, Vec xin, PetscDeviceContext dctx) noexcept
{
  const auto n = static_cast<cupmBlasInt_t>(yin->map->n);
  PetscBool  xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!xiscupm) {
    PetscCall(VecAYPX_Seq(yin, alpha, xin));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  if (alpha == PetscScalar(0.0)) {
    cupmStream_t stream;

    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto       alphaIsOne = alpha == PetscScalar(1.0);
    const auto       calpha     = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        const auto one = cupmScalarCast(1.0);

        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->aypx
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AYPXAsync(yin, alpha, xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecAXPYAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPYAsync(Vec yin, PetscScalar alpha, Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto       n = static_cast<cupmBlasInt_t>(yin->map->n);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n));
    if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpy
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AXPYAsync(yin, alpha, xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

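// note: a zero denominator deliberately yields zero rather than inf/nan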
struct divides {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return rhs == PetscScalar{0.0} ? rhs : lhs / rhs; }
};

} // namespace detail

// VecPointwiseDivideAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivideAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseDivide_Seq, detail::divides{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisedivide
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseDivideAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecPointwiseMultAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMultAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMult_Seq, thrust::multiplies<PetscScalar>{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisemult
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMultAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

struct MaximumRealPart {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::maximum<PetscReal>{}(PetscRealPart(lhs), PetscRealPart(rhs)); }
};

} // namespace detail

// VecPointwiseMaxAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMax_Seq, detail::MaximumRealPart{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisemax
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMax(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMaxAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

struct MaximumAbsoluteValue {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::maximum<PetscReal>{}(PetscAbsScalar(lhs), PetscAbsScalar(rhs)); }
};

} // namespace detail

// VecPointwiseMaxAbsAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAbsAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMaxAbs_Seq, detail::MaximumAbsoluteValue{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisemaxabs
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAbs(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMaxAbsAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

struct MinimumRealPart {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::minimum<PetscReal>{}(PetscRealPart(lhs), PetscRealPart(rhs)); }
};

} // namespace detail

// VecPointwiseMinAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMinAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMin_Seq, detail::MinimumRealPart{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisemin
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMin(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMinAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

struct Reciprocal {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept
  {
    // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
    // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
    // everything in PetscScalar...
    return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
  }
};

} // namespace detail

// VecReciprocalAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ReciprocalAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::Reciprocal{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->reciprocal
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(ReciprocalAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

struct AbsoluteValue {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscAbsScalar(s); }
};

} // namespace detail

// VecAbsAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AbsAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::AbsoluteValue{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->abs
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Abs(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AbsAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

struct SquareRootAbsoluteValue {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscSqrtReal(PetscAbsScalar(s)); }
};

} // namespace detail

// VecSqrtAbsAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SqrtAbsAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::SquareRootAbsoluteValue{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->sqrt
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SqrtAbs(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(SqrtAbsAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

struct Exponent {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscExpScalar(s); }
};

} // namespace detail

// VecExpAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ExpAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::Exponent{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->exp
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Exp(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(ExpAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

struct Logarithm {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscLogScalar(s); }
};

} // namespace detail

// VecLogAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::LogAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::Logarithm{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->log
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Log(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(LogAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecWAXPYAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::WAXPYAsync(Vec win, PetscScalar alpha, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscBool xiscupm, yiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!xiscupm || !yiscupm) {
    PetscCall(VecWAXPY_Seq(win, alpha, xin, yin));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  if (alpha == PetscScalar(0.0)) {
    PetscCall(CopyAsync(yin, win, dctx));
  } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
    cupmBlasHandle_t cupmBlasHandle;
    cupmStream_t     stream;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle, NULL, &stream));
    {
      const auto wptr = DeviceArrayWrite(dctx, win);

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
      PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops(2 * n));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->waxpy
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(WAXPYAsync(win, alpha, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace kernels
{

template <typename... Args>
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
  // these may look the same but give different results!
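  // (floating-point addition is not associative: the #if 0 branch sums the
  // aptr_shmem[j] * yptr_p[j][i] terms into 0.0 and adds xptr[i] last, while the #else
  // branch seeds the accumulator with xptr[i], so the roundoff differs)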
#if 0
    PetscScalar sum = 0.0;

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j] * yptr_p[j][i];
    xptr[i] += sum;
#else
    auto sum = xptr[i];

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j] * yptr_p[j][i];
    xptr[i] = sum;
#endif
  });
  return;
}

} // namespace kernels

namespace detail
{

// a helper struct to gobble the std::size_t input; it is used with template parameter pack
// expansion such that
//
//   typename repeat_type<MyType, Idx>::type...
//
// expands to
//
//   MyType, MyType, MyType, ... [repeated sizeof...(Idx) times]
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};
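
// For example (illustrative only): with util::index_sequence<0, 1, 2>, the pack expansion
//
//   kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>
//
// names kernels::MAXPY_kernel<const PetscScalar *, const PetscScalar *, const PetscScalar *>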

} // namespace detail

template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecMAXPYAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPYAsync(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin, PetscDeviceContext dctx) noexcept
{
  const auto   n = xin->map->n;
  cupmStream_t stream;
  PetscBool    yiscupm = PETSC_TRUE;

  PetscFunctionBegin;
  for (PetscInt i = 0; i < nv && yiscupm; i++) PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin[i]), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!yiscupm) {
    PetscCall(VecMAXPY_Seq(xin, nv, alpha, yin));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(GetHandlesFrom_(dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate: we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncrease()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->maxpy
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPYAsync(xin, nv, alpha, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscBool yiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!yiscupm) {
    PetscCall(VecDot_Seq(xin, yin, z));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
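    // i.e. cupmBlasXdot(h, n, a, 1, b, 1, z) computes z = sum_i conj(a_i) * b_i, while
    // VecDot(x, y) must return z = sum_i x_i * conj(y_i); passing (y, x) therefore yields
    // the PETSc result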
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

#define MDOT_WORKGROUP_NUM  128
#define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM
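
// Each of the nv dot products produces MDOT_WORKGROUP_NUM partial results (one per block);
// sum_kernel() below then folds each group of MDOT_WORKGROUP_NUM partials into a single
// value.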

namespace kernels
{

PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
{
  const auto group_entries = (size - 1) / gridDim.x + 1;
  // for very small vectors, a group should still do some work
  return group_entries ? group_entries : 1;
}

template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];
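  // layout: shmem[0 .. MDOT_WORKGROUP_SIZE) holds the per-thread partials for y[0], the
  // next MDOT_WORKGROUP_SIZE entries hold those for y[1], and so on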

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these goes on a separate line...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

#pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

#pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

#pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
#pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}

namespace
{

PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int         local_i = 0;
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  |  ______________________________________________________/
  //  | /            <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}

} // namespace

#if PetscDefined(USING_HCC)
namespace do_not_use
{

inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}

} // namespace do_not_use
#endif

} // namespace kernels

template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub-streams to create; if nv <= batchsize we can do this without looping, so we
  // do not create sub-streams. Note we don't create more than 8 streams: in practice we would
  // not gain more parallelism with higher numbers.
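  //
  // e.g. nv = 9   -> num_sub_streams = min((9 + 8) / 8, 8)   = 2
  //      nv = 100 -> num_sub_streams = min((100 + 8) / 8, 8) = 8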
  const auto   num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto   n               = xin->map->n;
  const auto   nwork           = nv * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_DEFAULT || stype == PETSC_STREAM_DEFAULT_WITH_BARRIER) stype = PETSC_STREAM_NONBLOCKING;
    // If we have a default stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1:
        PetscCall(MDot_kernel_dispatch_<1>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      default: // 8 or more
        PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  PetscCall(PetscCUPMLaunchKernel1D(nv, 0, stream, kernels::sum_kernel, nv, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nv, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogGpuFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}

#undef MDOT_WORKGROUP_NUM
#undef MDOT_WORKGROUP_SIZE

template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->mdot
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

1323: // VecSetAsync_Private
1324: template <device::cupm::DeviceType T>
1325: inline PetscErrorCode VecSeq_CUPM<T>::SetAsync(Vec xin, PetscScalar alpha, PetscDeviceContext dctx) noexcept
1326: {
1327:   const auto   n = xin->map->n;
1328:   cupmStream_t stream;

1330:   PetscFunctionBegin;
1331:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1332:   PetscCall(GetHandlesFrom_(dctx, &stream));
1333:   {
1334:     const auto xptr = DeviceArrayWrite(dctx, xin);

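    // an all-zero bit pattern is a valid representation of PetscScalar(0.0) for both real and
    // complex builds, so a byte-wise memset suffices; any other alpha is filled element-wise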
1336:     if (alpha == PetscScalar(0.0)) {
1337:       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1338:     } else {
1339:       const auto dptr = thrust::device_pointer_cast(xptr.data());

1341:       PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
1342:     }
1343:   }
1344:   if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1345:   PetscFunctionReturn(PETSC_SUCCESS);
1346: }

1348: // v->ops->set
1349: template <device::cupm::DeviceType T>
1350: inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
1351: {
1352:   PetscFunctionBegin;
1353:   PetscCall(SetAsync(xin, alpha, nullptr));
1354:   PetscFunctionReturn(PETSC_SUCCESS);
1355: }

1357: // VecScaleAsync_Private
1358: template <device::cupm::DeviceType T>
1359: inline PetscErrorCode VecSeq_CUPM<T>::ScaleAsync(Vec xin, PetscScalar alpha, PetscDeviceContext dctx) noexcept
1360: {
1361:   PetscFunctionBegin;
1362:   if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
1363:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1364:   if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1365:     PetscCall(SetAsync(xin, alpha, dctx));
1366:   } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1367:     cupmBlasHandle_t cupmBlasHandle;

1369:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1370:     PetscCall(PetscLogGpuTimeBegin());
1371:     PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1372:     PetscCall(PetscLogGpuTimeEnd());
1373:     PetscCall(PetscLogGpuFlops(n));
1374:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1375:   } else {
1376:     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1377:   }
1378:   PetscFunctionReturn(PETSC_SUCCESS);
1379: }

1381: // v->ops->scale
1382: template <device::cupm::DeviceType T>
1383: inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
1384: {
1385:   PetscFunctionBegin;
1386:   PetscCall(ScaleAsync(xin, alpha, nullptr));
1387:   PetscFunctionReturn(PETSC_SUCCESS);
1388: }

1390: // v->ops->tdot
1391: template <device::cupm::DeviceType T>
1392: inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
1393: {
1394:   PetscBool yiscupm;

1396:   PetscFunctionBegin;
1397:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1398:   if (!yiscupm) {
1399:     PetscCall(VecTDot_Seq(xin, yin, z));
1400:     PetscFunctionReturn(PETSC_SUCCESS);
1401:   }
1402:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1403:     PetscDeviceContext dctx;
1404:     cupmBlasHandle_t   cupmBlasHandle;

1406:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1407:     PetscCall(PetscLogGpuTimeBegin());
1408:     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1409:     PetscCall(PetscLogGpuTimeEnd());
1410:     PetscCall(PetscLogGpuFlops(2 * n - 1));
1411:   } else {
1412:     *z = 0.0;
1413:   }
1414:   PetscFunctionReturn(PETSC_SUCCESS);
1415: }

1417: // VecCopyAsync_Private
1418: template <device::cupm::DeviceType T>
1419: inline PetscErrorCode VecSeq_CUPM<T>::CopyAsync(Vec xin, Vec yout, PetscDeviceContext dctx) noexcept
1420: {
1421:   PetscFunctionBegin;
1422:   if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
1423:   if (const auto n = xin->map->n) {
1424:     const auto xmask = xin->offloadmask;
1425:     // silence buggy gcc warning: mode may be used uninitialized in this function
1426:     auto         mode = cupmMemcpyDeviceToDevice;
1427:     cupmStream_t stream;

1429:     // translate from PetscOffloadMask to cupmMemcpyKind
1430:     PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1431:     switch (const auto ymask = yout->offloadmask) {
1432:     case PETSC_OFFLOAD_UNALLOCATED: {
1433:       PetscBool yiscupm;

1435:       PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1436:       if (yiscupm) {
1437:         mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
1438:         break;
1439:       }
1440:     } // fall-through if unallocated and not cupm
1441: #if PETSC_CPP_VERSION >= 17
1442:       [[fallthrough]];
1443: #endif
1444:     case PETSC_OFFLOAD_CPU: {
1445:       PetscBool yiscupm;

1447:       PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1448:       if (yiscupm) {
1449:         mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToDevice : cupmMemcpyDeviceToDevice;
1450:       } else {
1451:         mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
1452:       }
1453:       break;
1454:     }
1455:     case PETSC_OFFLOAD_BOTH:
1456:     case PETSC_OFFLOAD_GPU:
1457:       mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
1458:       break;
1459:     default:
1460:       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
1461:     }

1463:     PetscCall(GetHandlesFrom_(dctx, &stream));
1464:     switch (mode) {
1465:     case cupmMemcpyDeviceToDevice: // the best case
1466:     case cupmMemcpyHostToDevice: { // not terrible
1467:       const auto yptr = DeviceArrayWrite(dctx, yout);
1468:       const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

1470:       PetscCall(PetscLogGpuTimeBegin());
1471:       PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
1472:       PetscCall(PetscLogGpuTimeEnd());
1473:     } break;
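    // destination ends up on the host: go through VecGetArrayWrite()/VecRestoreArrayWrite()
    // so that yout's offload mask reflects the host-side write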
1474:     case cupmMemcpyDeviceToHost: // not great
1475:     case cupmMemcpyHostToHost: { // worst case
1476:       const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
1477:       PetscScalar *yptr;

1479:       PetscCall(VecGetArrayWrite(yout, &yptr));
1480:       if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
1481:       PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
1482:       if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
1483:       PetscCall(VecRestoreArrayWrite(yout, &yptr));
1484:     } break;
1485:     default:
1486:       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
1487:     }
1488:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1489:   } else {
1490:     PetscCall(MaybeIncrementEmptyLocalVec(yout));
1491:   }
1492:   PetscFunctionReturn(PETSC_SUCCESS);
1493: }

1495: // v->ops->copy
1496: template <device::cupm::DeviceType T>
1497: inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
1498: {
1499:   PetscFunctionBegin;
1500:   PetscCall(CopyAsync(xin, yout, nullptr));
1501:   PetscFunctionReturn(PETSC_SUCCESS);
1502: }

1504: // VecSwapAsync_Private
1505: template <device::cupm::DeviceType T>
1506: inline PetscErrorCode VecSeq_CUPM<T>::SwapAsync(Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
1507: {
1508:   PetscBool yiscupm;

1510:   PetscFunctionBegin;
1511:   if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
1512:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1513:   PetscCheck(yiscupm, PetscObjectComm(PetscObjectCast(yin)), PETSC_ERR_SUP, "Cannot swap with Y of type %s", PetscObjectCast(yin)->type_name);
1514:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1515:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1516:     cupmBlasHandle_t cupmBlasHandle;

1518:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1519:     PetscCall(PetscLogGpuTimeBegin());
1520:     PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1521:     PetscCall(PetscLogGpuTimeEnd());
1522:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1523:   } else {
1524:     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1525:     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1526:   }
1527:   PetscFunctionReturn(PETSC_SUCCESS);
1528: }

1530: // v->ops->swap
1531: template <device::cupm::DeviceType T>
1532: inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
1533: {
1534:   PetscFunctionBegin;
1535:   PetscCall(SwapAsync(xin, yin, nullptr));
1536:   PetscFunctionReturn(PETSC_SUCCESS);
1537: }

1539: // VecAXPBYAsync_Private
1540: template <device::cupm::DeviceType T>
1541: inline PetscErrorCode VecSeq_CUPM<T>::AXPBYAsync(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin, PetscDeviceContext dctx) noexcept
1542: {
1543:   PetscBool xiscupm;

1545:   PetscFunctionBegin;
1546:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1547:   if (!xiscupm) {
1548:     PetscCall(VecAXPBY_Seq(yin, alpha, beta, xin));
1549:     PetscFunctionReturn(PETSC_SUCCESS);
1550:   }
1551:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1552:   if (alpha == PetscScalar(0.0)) {
1553:     PetscCall(ScaleAsync(yin, beta, dctx));
1554:   } else if (beta == PetscScalar(1.0)) {
1555:     PetscCall(AXPYAsync(yin, alpha, xin, dctx));
1556:   } else if (alpha == PetscScalar(1.0)) {
1557:     PetscCall(AYPXAsync(yin, beta, xin, dctx));
1558:   } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
1567:     const auto       betaIsZero = beta == PetscScalar(0.0);
1568:     const auto       aptr       = cupmScalarPtrCast(&alpha);
1569:     cupmBlasHandle_t cupmBlasHandle;

1571:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1572:     {
1573:       const auto xptr = DeviceArrayRead(dctx, xin);

1575:       if (betaIsZero /* beta = 0 */) {
1576:         // here we can get away with purely write-only as we memcpy into it first
1577:         const auto   yptr = DeviceArrayWrite(dctx, yin);
1578:         cupmStream_t stream;

1580:         PetscCall(GetHandlesFrom_(dctx, &stream));
1581:         PetscCall(PetscLogGpuTimeBegin());
1582:         PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
1583:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
1584:       } else {
1585:         const auto yptr = DeviceArrayReadWrite(dctx, yin);

1587:         PetscCall(PetscLogGpuTimeBegin());
1588:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
1589:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
1590:       }
1591:     }
1592:     PetscCall(PetscLogGpuTimeEnd());
1593:     PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
1594:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1595:   } else {
1596:     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1597:   }
1598:   PetscFunctionReturn(PETSC_SUCCESS);
1599: }

1601: // v->ops->axpby
1602: template <device::cupm::DeviceType T>
1603: inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
1604: {
1605:   PetscFunctionBegin;
1606:   PetscCall(AXPBYAsync(yin, alpha, beta, xin, nullptr));
1607:   PetscFunctionReturn(PETSC_SUCCESS);
1608: }

1610: // VecAXPBYPCZAsync_Private
1611: template <device::cupm::DeviceType T>
1612: inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZAsync(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
1613: {
1614:   PetscFunctionBegin;
1615:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1616:   if (gamma != PetscScalar(1.0)) PetscCall(ScaleAsync(zin, gamma, dctx));
1617:   PetscCall(AXPYAsync(zin, alpha, xin, dctx));
1618:   PetscCall(AXPYAsync(zin, beta, yin, dctx));
1619:   PetscFunctionReturn(PETSC_SUCCESS);
1620: }

1622: // v->ops->axpbypcz
1623: template <device::cupm::DeviceType T>
1624: inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
1625: {
1626:   PetscFunctionBegin;
1627:   PetscCall(AXPBYPCZAsync(zin, alpha, beta, gamma, xin, yin, nullptr));
1628:   PetscFunctionReturn(PETSC_SUCCESS);
1629: }

1631: // v->ops->norm
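// (e.g. VecNorm(x, NORM_1_AND_2, z) expects a 2-entry z: z[0] receives the 1-norm and, via
// the ++z fall-through below, z[1] receives the 2-norm)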
1632: template <device::cupm::DeviceType T>
1633: inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
1634: {
1635:   PetscDeviceContext dctx;
1636:   cupmBlasHandle_t   cupmBlasHandle;

1638:   PetscFunctionBegin;
1639:   PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1640:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1641:     const auto xptr      = DeviceArrayRead(dctx, xin);
1642:     PetscInt   flopCount = 0;

1644:     PetscCall(PetscLogGpuTimeBegin());
1645:     switch (type) {
1646:     case NORM_1_AND_2:
1647:     case NORM_1:
1648:       PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
1649:       flopCount = std::max(n - 1, 0);
1650:       if (type == NORM_1) break;
1651:       ++z; // fall-through
1652: #if PETSC_CPP_VERSION >= 17
1653:       [[fallthrough]];
1654: #endif
1655:     case NORM_2:
1656:     case NORM_FROBENIUS:
1657:       PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
1658:       flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
1659:       break;
1660:     case NORM_INFINITY: {
1661:       cupmBlasInt_t max_loc = 0;
1662:       PetscScalar   xv      = 0.;
1663:       cupmStream_t  stream;

1665:       PetscCall(GetHandlesFrom_(dctx, &stream));
1666:       PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
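      // BLAS amax returns a 1-based index, hence the "- 1" in the memcpy below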
1667:       PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
1668:       *z = PetscAbsScalar(xv);
1669:       // REVIEW ME: flopCount = ???
1670:     } break;
1671:     }
1672:     PetscCall(PetscLogGpuTimeEnd());
1673:     PetscCall(PetscLogGpuFlops(flopCount));
1674:   } else {
1675:     z[0]                    = 0.0;
1676:     z[type == NORM_1_AND_2] = 0.0; // when both norms were requested this also zeroes z[1]
1677:   }
1678:   PetscFunctionReturn(PETSC_SUCCESS);
1679: }

1681: namespace detail
1682: {

1684: template <NormType wnormtype>
1685: class ErrorWNormTransformBase {
1686: public:
1687:   using result_type = thrust::tuple<PetscReal, PetscReal, PetscReal, PetscInt, PetscInt, PetscInt>;

1689:   constexpr explicit ErrorWNormTransformBase(PetscReal v) noexcept : ignore_max_{v} { }

1691: protected:
1692:   struct NormTuple {
1693:     PetscReal norm;
1694:     PetscInt  loc;
1695:   };

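  // per-entry contribution: err/tol (squared unless computing the infinity norm) together
  // with a count of 1 if the entry participated (tol > 0) and 0 if it was skipped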
1697:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL static NormTuple compute_norm_(PetscReal err, PetscReal tol) noexcept
1698:   {
1699:     if (tol > 0.) {
1700:       const auto val = err / tol;

1702:       return {wnormtype == NORM_INFINITY ? val : PetscSqr(val), 1};
1703:     } else {
1704:       return {0.0, 0};
1705:     }
1706:   }

1708:   PetscReal ignore_max_;
1709: };

1711: template <NormType wnormtype>
1712: struct ErrorWNormTransform : ErrorWNormTransformBase<wnormtype> {
1713:   using base_type     = ErrorWNormTransformBase<wnormtype>;
1714:   using result_type   = typename base_type::result_type;
1715:   using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

1717:   using base_type::base_type;

1719:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
1720:   {
1721:     const auto u     = thrust::get<0>(x); // with x.get<0>(), cuda-12.4.0 gives error: class "cuda::std::__4::tuple" has no member "get"
1722:     const auto y     = thrust::get<1>(x);
1723:     const auto au    = PetscAbsScalar(u);
1724:     const auto ay    = PetscAbsScalar(y);
1725:     const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
1726:     const auto tola  = skip ? 0.0 : PetscRealPart(thrust::get<2>(x));
1727:     const auto tolr  = skip ? 0.0 : PetscRealPart(thrust::get<3>(x)) * PetscMax(au, ay);
1728:     const auto tol   = tola + tolr;
1729:     const auto err   = PetscAbsScalar(u - y);
1730:     const auto tup_a = this->compute_norm_(err, tola);
1731:     const auto tup_r = this->compute_norm_(err, tolr);
1732:     const auto tup_n = this->compute_norm_(err, tol);

1734:     return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
1735:   }
1736: };

1738: template <NormType wnormtype>
1739: struct ErrorWNormETransform : ErrorWNormTransformBase<wnormtype> {
1740:   using base_type     = ErrorWNormTransformBase<wnormtype>;
1741:   using result_type   = typename base_type::result_type;
1742:   using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

1744:   using base_type::base_type;

1746:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
1747:   {
1748:     const auto au    = PetscAbsScalar(thrust::get<0>(x));
1749:     const auto ay    = PetscAbsScalar(thrust::get<1>(x));
1750:     const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
1751:     const auto tola  = skip ? 0.0 : PetscRealPart(thrust::get<3>(x));
1752:     const auto tolr  = skip ? 0.0 : PetscRealPart(thrust::get<4>(x)) * PetscMax(au, ay);
1753:     const auto tol   = tola + tolr;
1754:     const auto err   = PetscAbsScalar(thrust::get<2>(x));
1755:     const auto tup_a = this->compute_norm_(err, tola);
1756:     const auto tup_r = this->compute_norm_(err, tolr);
1757:     const auto tup_n = this->compute_norm_(err, tol);

1759:     return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
1760:   }
1761: };

1763: template <NormType wnormtype>
1764: struct ErrorWNormReduce {
1765:   using value_type = typename ErrorWNormTransformBase<wnormtype>::result_type;

1767:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept
1768:   {
1769:     // cannot use lhs.get<0>() etc since the using decl above ambiguates the fact that
1770:     // result_type is a template, so in order to fix this we would need to write:
1771:     //
1772:     // lhs.template get<0>()
1773:     //
1774:     // which is unseemly.
1775:     if (wnormtype == NORM_INFINITY) {
1776:       // clang-format off
1777:       return {
1778:         PetscMax(thrust::get<0>(lhs), thrust::get<0>(rhs)),
1779:         PetscMax(thrust::get<1>(lhs), thrust::get<1>(rhs)),
1780:         PetscMax(thrust::get<2>(lhs), thrust::get<2>(rhs)),
1781:         thrust::get<3>(lhs) + thrust::get<3>(rhs),
1782:         thrust::get<4>(lhs) + thrust::get<4>(rhs),
1783:         thrust::get<5>(lhs) + thrust::get<5>(rhs)
1784:       };
1785:       // clang-format on
1786:     } else {
1787:       // clang-format off
1788:       return {
1789:         thrust::get<0>(lhs) + thrust::get<0>(rhs),
1790:         thrust::get<1>(lhs) + thrust::get<1>(rhs),
1791:         thrust::get<2>(lhs) + thrust::get<2>(rhs),
1792:         thrust::get<3>(lhs) + thrust::get<3>(rhs),
1793:         thrust::get<4>(lhs) + thrust::get<4>(rhs),
1794:         thrust::get<5>(lhs) + thrust::get<5>(rhs)
1795:       };
1796:       // clang-format on
1797:     }
1798:   }
1799: };

1801: template <template <NormType> class WNormTransformType, typename Tuple, typename cupmStream_t>
1802: inline PetscErrorCode ExecuteWNorm(Tuple &&first, Tuple &&last, NormType wnormtype, cupmStream_t stream, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
1803: {
1804:   auto      begin = thrust::make_zip_iterator(std::forward<Tuple>(first));
1805:   auto      end   = thrust::make_zip_iterator(std::forward<Tuple>(last));
1806:   PetscReal n = 0, na = 0, nr = 0;
1807:   PetscInt  n_loc = 0, na_loc = 0, nr_loc = 0;

1809:   PetscFunctionBegin;
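  // only two reduction flavors are instantiated: NORM_INFINITY max-reduces the norms,
  // everything else sums them (the square root for NORM_2 is applied after the reduction)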
1810:   // clang-format off
1811:   if (wnormtype == NORM_INFINITY) {
1812:     PetscCallThrust(
1813:       thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
1814:         thrust::transform_reduce,
1815:         stream,
1816:         std::move(begin),
1817:         std::move(end),
1818:         WNormTransformType<NORM_INFINITY>{ignore_max},
1819:         thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
1820:         ErrorWNormReduce<NORM_INFINITY>{}
1821:       )
1822:     );
1823:   } else {
1824:     PetscCallThrust(
1825:       thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
1826:         thrust::transform_reduce,
1827:         stream,
1828:         std::move(begin),
1829:         std::move(end),
1830:         WNormTransformType<NORM_2>{ignore_max},
1831:         thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
1832:         ErrorWNormReduce<NORM_2>{}
1833:       )
1834:     );
1835:   }
1836:   // clang-format on
1837:   if (wnormtype == NORM_2) {
1838:     *norm  = PetscSqrtReal(*norm);
1839:     *norma = PetscSqrtReal(*norma);
1840:     *normr = PetscSqrtReal(*normr);
1841:   }
1842:   PetscFunctionReturn(PETSC_SUCCESS);
1843: }

1845: } // namespace detail

1847: // v->ops->errorwnorm
1848: template <device::cupm::DeviceType T>
1849: inline PetscErrorCode VecSeq_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
1850: {
1851:   const auto         nl  = U->map->n;
1852:   auto               ait = thrust::make_constant_iterator(static_cast<PetscScalar>(atol));
1853:   auto               rit = thrust::make_constant_iterator(static_cast<PetscScalar>(rtol));
1854:   PetscDeviceContext dctx;
1855:   cupmStream_t       stream;

1857:   PetscFunctionBegin;
1858:   PetscCall(GetHandles_(&dctx, &stream));
1859:   {
1860:     const auto ConditionalDeviceArrayRead = [&](Vec v) {
1861:       if (v) {
1862:         return thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1863:       } else {
1864:         return thrust::device_ptr<PetscScalar>{nullptr};
1865:       }
1866:     };

1868:     const auto uarr = DeviceArrayRead(dctx, U);
1869:     const auto yarr = DeviceArrayRead(dctx, Y);
1870:     const auto uptr = thrust::device_pointer_cast(uarr.data());
1871:     const auto yptr = thrust::device_pointer_cast(yarr.data());
1872:     const auto eptr = ConditionalDeviceArrayRead(E);
1873:     const auto rptr = ConditionalDeviceArrayRead(vrtol);
1874:     const auto aptr = ConditionalDeviceArrayRead(vatol);

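    // dispatch on which of E, vatol, and vrtol were provided; an absent tolerance vector is
    // modeled as a constant_iterator over the scalar atol/rtol so that all eight cases can
    // reuse the same transform functors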
1876:     if (!vatol && !vrtol) {
1877:       if (E) {
1878:         // clang-format off
1879:         PetscCall(
1880:           detail::ExecuteWNorm<detail::ErrorWNormETransform>(
1881:             thrust::make_tuple(uptr, yptr, eptr, ait, rit),
1882:             thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rit),
1883:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1884:           )
1885:         );
1886:         // clang-format on
1887:       } else {
1888:         // clang-format off
1889:         PetscCall(
1890:           detail::ExecuteWNorm<detail::ErrorWNormTransform>(
1891:             thrust::make_tuple(uptr, yptr, ait, rit),
1892:             thrust::make_tuple(uptr + nl, yptr + nl, ait, rit),
1893:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1894:           )
1895:         );
1896:         // clang-format on
1897:       }
1898:     } else if (!vatol) {
1899:       if (E) {
1900:         // clang-format off
1901:         PetscCall(
1902:           detail::ExecuteWNorm<detail::ErrorWNormETransform>(
1903:             thrust::make_tuple(uptr, yptr, eptr, ait, rptr),
1904:             thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rptr + nl),
1905:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1906:           )
1907:         );
1908:         // clang-format on
1909:       } else {
1910:         // clang-format off
1911:         PetscCall(
1912:           detail::ExecuteWNorm<detail::ErrorWNormTransform>(
1913:             thrust::make_tuple(uptr, yptr, ait, rptr),
1914:             thrust::make_tuple(uptr + nl, yptr + nl, ait, rptr + nl),
1915:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1916:           )
1917:         );
1918:         // clang-format on
1919:       }
1920:     } else if (!vrtol) {
1921:       if (E) {
1922:         // clang-format off
1923:         PetscCall(
1924:           detail::ExecuteWNorm<detail::ErrorWNormETransform>(
1925:             thrust::make_tuple(uptr, yptr, eptr, aptr, rit),
1926:             thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rit),
1927:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1928:           )
1929:         );
1930:         // clang-format on
1931:       } else {
1932:         // clang-format off
1933:         PetscCall(
1934:           detail::ExecuteWNorm<detail::ErrorWNormTransform>(
1935:             thrust::make_tuple(uptr, yptr, aptr, rit),
1936:             thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rit),
1937:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1938:           )
1939:         );
1940:         // clang-format on
1941:       }
1942:     } else {
1943:       if (E) {
1944:         // clang-format off
1945:         PetscCall(
1946:           detail::ExecuteWNorm<detail::ErrorWNormETransform>(
1947:             thrust::make_tuple(uptr, yptr, eptr, aptr, rptr),
1948:             thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rptr + nl),
1949:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1950:           )
1951:         );
1952:         // clang-format on
1953:       } else {
1954:         // clang-format off
1955:         PetscCall(
1956:           detail::ExecuteWNorm<detail::ErrorWNormTransform>(
1957:             thrust::make_tuple(uptr, yptr, aptr, rptr),
1958:             thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rptr + nl),
1959:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1960:           )
1961:         );
1962:         // clang-format on
1963:       }
1964:     }
1965:   }
1966:   PetscFunctionReturn(PETSC_SUCCESS);
1967: }

1969: namespace detail
1970: {
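// maps each pair (s_i, t_i) to (s_i * conj(t_i), t_i * conj(t_i)); summing these tuples via
// thrust::inner_product yields the dot product and ||t||^2 in a single fused pass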
1971: struct dotnorm2_mult {
1972:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
1973:   {
1974:     const auto conjt = PetscConj(t);

1976:     return {s * conjt, t * conjt};
1977:   }
1978: };

1980: // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1981: // would do it myself but now I am worried that they do so on purpose...
1982: struct dotnorm2_tuple_plus {
1983:   using value_type = thrust::tuple<PetscScalar, PetscScalar>;

1985:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {thrust::get<0>(lhs) + thrust::get<0>(rhs), thrust::get<1>(lhs) + thrust::get<1>(rhs)}; }
1986: };

1988: } // namespace detail

1990: // v->ops->dotnorm2
1991: template <device::cupm::DeviceType T>
1992: inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
1993: {
1994:   PetscDeviceContext dctx;
1995:   cupmStream_t       stream;

1997:   PetscFunctionBegin;
1998:   PetscCall(GetHandles_(&dctx, &stream));
1999:   {
2000:     PetscScalar dpt = 0.0, nmt = 0.0;
2001:     const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

2003:     // clang-format off
2004:     PetscCallThrust(
2005:       thrust::tie(*dp, *nm) = THRUST_CALL(
2006:         thrust::inner_product,
2007:         stream,
2008:         sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
2009:         thrust::make_tuple(dpt, nmt),
2010:         detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
2011:       );
2012:     );
2013:     // clang-format on
2014:   }
2015:   PetscFunctionReturn(PETSC_SUCCESS);
2016: }

2018: namespace detail
2019: {
2020: struct conjugate {
2021:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &x) const noexcept { return PetscConj(x); }
2022: };

2024: } // namespace detail

2026: // VecConjugateAsync_Private
2027: template <device::cupm::DeviceType T>
2028: inline PetscErrorCode VecSeq_CUPM<T>::ConjugateAsync(Vec xin, PetscDeviceContext dctx) noexcept
2029: {
2030:   PetscFunctionBegin;
2031:   if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin, nullptr, dctx));
2032:   PetscFunctionReturn(PETSC_SUCCESS);
2033: }

2035: // v->ops->conjugate
2036: template <device::cupm::DeviceType T>
2037: inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
2038: {
2039:   PetscFunctionBegin;
2040:   PetscCall(ConjugateAsync(xin, nullptr));
2041:   PetscFunctionReturn(PETSC_SUCCESS);
2042: }

2044: namespace detail
2045: {

2047: struct real_part {
2048:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const noexcept { return {PetscRealPart(thrust::get<0>(x)), thrust::get<1>(x)}; }

2050:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(const PetscScalar &x) const noexcept { return PetscRealPart(x); }
2051: };

2053: // deriving from Operator allows us to "store" an instance of the operator in the class but
2054: // also take advantage of empty base class optimization if the operator is stateless
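// (e.g. a stateless thrust::less<PetscReal> base contributes zero bytes here, whereas a data
// member of an empty class type would still occupy at least one byte)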
2055: template <typename Operator>
2056: class tuple_compare : Operator {
2057: public:
2058:   using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
2059:   using operator_type = Operator;

2061:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
2062:   {
2063:     if (op_()(thrust::get<0>(y), thrust::get<0>(x))) {
2064:       // if y is strictly greater/less than x, return y
2065:       return y;
2066:     } else if (thrust::get<0>(y) == thrust::get<0>(x)) {
2067:       // if equal, prefer lower index
2068:       return thrust::get<1>(y) < thrust::get<1>(x) ? y : x;
2069:     }
2070:     // otherwise return x
2071:     return x;
2072:   }

2074: private:
2075:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
2076: };

2078: } // namespace detail

2080: template <device::cupm::DeviceType T>
2081: template <typename TupleFuncT, typename UnaryFuncT>
2082: inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
2083: {
2084:   PetscFunctionBegin;
2085:   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
2086:   if (p) *p = -1;
2087:   if (const auto n = v->map->n) {
2088:     PetscDeviceContext dctx;
2089:     cupmStream_t       stream;

2091:     PetscCall(GetHandles_(&dctx, &stream));
2092:     // needed to:
2093:     // 1. switch between transform_reduce and reduce
2094:     // 2. strip the real_part functor from the arguments
2095: #if PetscDefined(USE_COMPLEX)
2096:   #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
2097: #else
2098:   #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
2099: #endif
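    // complex builds must go through thrust::transform_reduce so that detail::real_part{}
    // can project out the real part first; real builds drop the functor and use plain
    // thrust::reduce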
2100:     {
2101:       const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

2103:       if (p) {
2104:         // clang-format off
2105:         const auto zip = thrust::make_zip_iterator(
2106:           thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
2107:         );
2108:         // clang-format on
2109:         // need to use preprocessor conditionals since otherwise thrust complains about not being
2110:         // able to convert a thrust::device_reference to a PetscReal on complex
2111:         // builds...
2112:         // clang-format off
2113:         PetscCallThrust(
2114:           thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
2115:             stream, zip, zip + n, detail::real_part{},
2116:             thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
2117:           );
2118:         );
2119:         // clang-format on
2120:       } else {
2121:         // clang-format off
2122:         PetscCallThrust(
2123:           *m = THRUST_MINMAX_REDUCE(
2124:             stream, vptr, vptr + n, detail::real_part{},
2125:             *m, std::forward<UnaryFuncT>(unary_ftr)
2126:           );
2127:         );
2128:         // clang-format on
2129:       }
2130:     }
2131: #undef THRUST_MINMAX_REDUCE
2132:   }
2133:   // REVIEW ME: flops?
2134:   PetscFunctionReturn(PETSC_SUCCESS);
2135: }

2137: // v->ops->max
2138: template <device::cupm::DeviceType T>
2139: inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept
2140: {
2141:   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
2142:   using unary_functor = thrust::maximum<PetscReal>;

2144:   PetscFunctionBegin;
2145:   *m = PETSC_MIN_REAL;
2146:   // use {} constructor syntax otherwise most vexing parse
2147:   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
2148:   PetscFunctionReturn(PETSC_SUCCESS);
2149: }

2151: // v->ops->min
2152: template <device::cupm::DeviceType T>
2153: inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept
2154: {
2155:   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
2156:   using unary_functor = thrust::minimum<PetscReal>;

2158:   PetscFunctionBegin;
2159:   *m = PETSC_MAX_REAL;
2160:   // use {} constructor syntax otherwise most vexing parse
2161:   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
2162:   PetscFunctionReturn(PETSC_SUCCESS);
2163: }

2165: // v->ops->sum
2166: template <device::cupm::DeviceType T>
2167: inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept
2168: {
2169:   PetscFunctionBegin;
2170:   if (const auto n = v->map->n) {
2171:     PetscDeviceContext dctx;
2172:     cupmStream_t       stream;

2174:     PetscCall(GetHandles_(&dctx, &stream));
2175:     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
2176:     // REVIEW ME: why not cupmBlasXasum()?
2177:     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
2178:     // REVIEW ME: must be at least n additions
2179:     PetscCall(PetscLogGpuFlops(n));
2180:   } else {
2181:     *sum = 0.0;
2182:   }
2183:   PetscFunctionReturn(PETSC_SUCCESS);
2184: }

2186: template <device::cupm::DeviceType T>
2187: inline PetscErrorCode VecSeq_CUPM<T>::ShiftAsync(Vec v, PetscScalar shift, PetscDeviceContext dctx) noexcept
2188: {
2189:   PetscFunctionBegin;
2190:   PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v, nullptr, dctx));
2191:   PetscFunctionReturn(PETSC_SUCCESS);
2192: }

2194: template <device::cupm::DeviceType T>
2195: inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept
2196: {
2197:   PetscFunctionBegin;
2198:   PetscCall(ShiftAsync(v, shift, nullptr));
2199:   PetscFunctionReturn(PETSC_SUCCESS);
2200: }

2202: template <device::cupm::DeviceType T>
2203: inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept
2204: {
2205:   PetscFunctionBegin;
2206:   if (const auto n = v->map->n) {
2207:     PetscBool          iscurand;
2208:     PetscDeviceContext dctx;

2210:     PetscCall(GetHandles_(&dctx));
2211:     PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
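    // PETSCCURAND can generate directly into device memory; any other PetscRandom fills the
    // host buffer, leaving the usual host-to-device migration to the next device access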
2212:     if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
2213:     else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
2214:   } else {
2215:     PetscCall(MaybeIncrementEmptyLocalVec(v));
2216:   }
2217:   // REVIEW ME: flops????
2218:   // REVIEW ME: Timing???
2219:   PetscFunctionReturn(PETSC_SUCCESS);
2220: }

2222: // v->ops->setpreallocationcoo
2223: template <device::cupm::DeviceType T>
2224: inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
2225: {
2226:   PetscDeviceContext dctx;

2228:   PetscFunctionBegin;
2229:   PetscCall(GetHandles_(&dctx));
2230:   PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
2231:   PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
2232:   PetscFunctionReturn(PETSC_SUCCESS);
2233: }

2235: // v->ops->setvaluescoo
2236: template <device::cupm::DeviceType T>
2237: inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
2238: {
2239:   auto               vv = const_cast<PetscScalar *>(v);
2240:   PetscMemType       memtype;
2241:   PetscDeviceContext dctx;
2242:   cupmStream_t       stream;

2244:   PetscFunctionBegin;
2245:   PetscCall(GetHandles_(&dctx, &stream));
2246:   PetscCall(PetscGetMemType(v, &memtype));
2247:   if (PetscMemTypeHost(memtype)) {
2248:     const auto size = VecIMPLCast(x)->coo_n;

2250:     // the user passed v[] on the host, so stage a device copy for the kernel below
2251:     PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
2252:     PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
2253:   }

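  // roughly: kernels::add_coo_values combines, for each owned entry, all COO values mapping
  // to it (via the jmap1_d offsets and perm1_d permutation built in SetPreallocationCOO) and
  // either inserts or adds depending on imode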
2255:   if (const auto n = x->map->n) {
2256:     const auto vcu = VecCUPMCast(x);

2258:     PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
2259:   } else {
2260:     PetscCall(MaybeIncrementEmptyLocalVec(x));
2261:   }

2263:   if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
2264:   PetscCall(PetscDeviceContextSynchronize(dctx));
2265:   PetscFunctionReturn(PETSC_SUCCESS);
2266: }

2268: } // namespace impl

2270: } // namespace cupm

2272: } // namespace vec

2274: } // namespace Petsc