Actual source code: vecseqcupm_impl.hpp

  1: #pragma once

  3: #include "vecseqcupm.hpp"

  5: #include <petsc/private/randomimpl.h>

  7: #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
  8: #include "../src/sys/objects/device/impls/cupm/kernels.hpp"

 10: #if PetscDefined(USE_COMPLEX)
 11:   #include <thrust/transform_reduce.h>
 12: #endif
 13: #include <thrust/transform.h>
 14: #include <thrust/reduce.h>
 15: #include <thrust/functional.h>
 16: #include <thrust/tuple.h>
 17: #include <thrust/device_ptr.h>
 18: #include <thrust/iterator/zip_iterator.h>
 19: #include <thrust/iterator/counting_iterator.h>
 20: #include <thrust/iterator/constant_iterator.h>
 21: #include <thrust/inner_product.h>

 23: namespace Petsc
 24: {

 26: namespace vec
 27: {

 29: namespace cupm
 30: {

 32: namespace impl
 33: {

 35: // ==========================================================================================
 36: // VecSeq_CUPM - Private API
 37: // ==========================================================================================

 39: template <device::cupm::DeviceType T>
 40: inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
 41: {
 42:   return static_cast<Vec_Seq *>(v->data);
 43: }

 45: template <device::cupm::DeviceType T>
 46: inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
 47: {
 48:   return VECSEQCUPM();
 49: }

 51: template <device::cupm::DeviceType T>
 52: inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
 53: {
 54:   return VECSEQ;
 55: }

 57: template <device::cupm::DeviceType T>
 58: inline PetscErrorCode VecSeq_CUPM<T>::ClearAsyncFunctions(Vec v) noexcept
 59: {
 60:   PetscFunctionBegin;
 61:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Abs), nullptr));
 62:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBY), nullptr));
 63:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBYPCZ), nullptr));
 64:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPY), nullptr));
 65:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AYPX), nullptr));
 66:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Conjugate), nullptr));
 67:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Copy), nullptr));
 68:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Exp), nullptr));
 69:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Log), nullptr));
 70:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(MAXPY), nullptr));
 71:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseDivide), nullptr));
 72:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMax), nullptr));
 73:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMaxAbs), nullptr));
 74:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMin), nullptr));
 75:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMult), nullptr));
 76:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseSign), nullptr));
 77:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Reciprocal), nullptr));
 78:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Scale), nullptr));
 79:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Set), nullptr));
 80:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Shift), nullptr));
 81:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(SqrtAbs), nullptr));
 82:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Swap), nullptr));
 83:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(WAXPY), nullptr));
 84:   PetscFunctionReturn(PETSC_SUCCESS);
 85: }

 87: template <device::cupm::DeviceType T>
 88: inline PetscErrorCode VecSeq_CUPM<T>::InitializeAsyncFunctions(Vec v) noexcept
 89: {
 90:   PetscFunctionBegin;
 91:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Abs), VecSeq_CUPM<T>::AbsAsync));
 92:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBY), VecSeq_CUPM<T>::AXPBYAsync));
 93:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBYPCZ), VecSeq_CUPM<T>::AXPBYPCZAsync));
 94:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPY), VecSeq_CUPM<T>::AXPYAsync));
 95:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AYPX), VecSeq_CUPM<T>::AYPXAsync));
 96:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Conjugate), VecSeq_CUPM<T>::ConjugateAsync));
 97:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Copy), VecSeq_CUPM<T>::CopyAsync));
 98:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Exp), VecSeq_CUPM<T>::ExpAsync));
 99:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Log), VecSeq_CUPM<T>::LogAsync));
100:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(MAXPY), VecSeq_CUPM<T>::MAXPYAsync));
101:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseDivide), VecSeq_CUPM<T>::PointwiseDivideAsync));
102:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMax), VecSeq_CUPM<T>::PointwiseMaxAsync));
103:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMaxAbs), VecSeq_CUPM<T>::PointwiseMaxAbsAsync));
104:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMin), VecSeq_CUPM<T>::PointwiseMinAsync));
105:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMult), VecSeq_CUPM<T>::PointwiseMultAsync));
106:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseSign), VecSeq_CUPM<T>::PointwiseSignAsync));
107:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Reciprocal), VecSeq_CUPM<T>::ReciprocalAsync));
108:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Scale), VecSeq_CUPM<T>::ScaleAsync));
109:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Set), VecSeq_CUPM<T>::SetAsync));
110:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Shift), VecSeq_CUPM<T>::ShiftAsync));
111:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(SqrtAbs), VecSeq_CUPM<T>::SqrtAbsAsync));
112:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Swap), VecSeq_CUPM<T>::SwapAsync));
113:   PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(WAXPY), VecSeq_CUPM<T>::WAXPYAsync));
114:   PetscFunctionReturn(PETSC_SUCCESS);
115: }

117: template <device::cupm::DeviceType T>
118: inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
119: {
120:   PetscFunctionBegin;
121:   PetscCall(ClearAsyncFunctions(v));
122:   PetscCall(VecDestroy_Seq(v));
123:   PetscFunctionReturn(PETSC_SUCCESS);
124: }

126: template <device::cupm::DeviceType T>
127: inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
128: {
129:   return VecResetArray_Seq(v);
130: }

132: template <device::cupm::DeviceType T>
133: inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
134: {
135:   return VecPlaceArray_Seq(v, a);
136: }

138: template <device::cupm::DeviceType T>
139: inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
140: {
141:   PetscMPIInt size;

143:   PetscFunctionBegin;
144:   if (alloc_missing) *alloc_missing = PETSC_FALSE;
145:   PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
146:   PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
147:   PetscCall(VecCreate_Seq_Private(v, host_array));
148:   PetscCall(InitializeAsyncFunctions(v));
149:   PetscFunctionReturn(PETSC_SUCCESS);
150: }

152: // for functions with an early return based one vec size we still need to artificially bump the
153: // object state. This is to prevent the following:
154: //
155: // 0. Suppose you have a Vec {
156: //   rank 0: [0],
157: //   rank 1: []
158: // }
159: // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
160: // 2. Vec enters e.g. VecSet(10)
161: // 3. rank 1 has local size 0 and bails immediately
162: // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
163: // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
164: // 6. Vec enters VecNorm(), and calls VecNormAvailable()
165: // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
166: // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
167: // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
168: template <device::cupm::DeviceType T>
169: inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
170: {
171:   PetscFunctionBegin;
172:   if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
173:   PetscFunctionReturn(PETSC_SUCCESS);
174: }

176: template <device::cupm::DeviceType T>
177: inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
178: {
179:   PetscFunctionBegin;
180:   PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
181:   PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
182:   PetscFunctionReturn(PETSC_SUCCESS);
183: }

185: template <device::cupm::DeviceType T>
186: template <typename BinaryFuncT>
187: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout, PetscDeviceContext dctx) noexcept
188: {
189:   PetscFunctionBegin;
190:   if (const auto n = zout->map->n) {
191:     cupmStream_t stream;

193:     PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
194:     PetscCall(GetHandlesFrom_(dctx, &stream));
195:     // clang-format off
196:     PetscCallThrust(
197:       const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

199:       THRUST_CALL(
200:         thrust::transform,
201:         stream,
202:         dxptr, dxptr + n,
203:         thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
204:         thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
205:         std::forward<BinaryFuncT>(binary)
206:       )
207:     );
208:     // clang-format on
209:     PetscCall(PetscLogGpuFlops(n));
210:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
211:   } else {
212:     PetscCall(MaybeIncrementEmptyLocalVec(zout));
213:   }
214:   PetscFunctionReturn(PETSC_SUCCESS);
215: }

217: template <device::cupm::DeviceType T>
218: template <typename BinaryFuncT>
219: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinaryDispatch_(PetscErrorCode (*VecSeqFunction)(Vec, Vec, Vec), BinaryFuncT &&binary, Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
220: {
221:   PetscFunctionBegin;
222:   if (xin->boundtocpu || yin->boundtocpu) PetscCall((*VecSeqFunction)(wout, xin, yin));
223:   else PetscCall(PointwiseBinary_(std::forward<BinaryFuncT>(binary), xin, yin, wout, dctx)); // note order of arguments! xin and yin are read, wout is written!
224:   PetscFunctionReturn(PETSC_SUCCESS);
225: }

227: template <device::cupm::DeviceType T>
228: template <typename UnaryFuncT>
229: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yout, PetscDeviceContext dctx) noexcept
230: {
231:   const auto inplace = !yout || (xinout == yout);

233:   PetscFunctionBegin;
234:   if (const auto n = xinout->map->n) {
235:     cupmStream_t stream;
236:     const auto   apply = [&](PetscScalar *xinout, PetscScalar *yout = nullptr) {
237:       PetscFunctionBegin;
238:       // clang-format off
239:       PetscCallThrust(
240:         const auto xptr = thrust::device_pointer_cast(xinout);

242:         THRUST_CALL(
243:           thrust::transform,
244:           stream,
245:           xptr, xptr + n,
246:           (yout && (yout != xinout)) ? thrust::device_pointer_cast(yout) : xptr,
247:           std::forward<UnaryFuncT>(unary)
248:         )
249:       );
250:       // clang-format on
251:       PetscFunctionReturn(PETSC_SUCCESS);
252:     };

254:     PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
255:     PetscCall(GetHandlesFrom_(dctx, &stream));
256:     if (inplace) {
257:       PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
258:     } else {
259:       PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yout).data()));
260:     }
261:     PetscCall(PetscLogGpuFlops(n));
262:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
263:   } else {
264:     if (inplace) {
265:       PetscCall(MaybeIncrementEmptyLocalVec(xinout));
266:     } else {
267:       PetscCall(MaybeIncrementEmptyLocalVec(yout));
268:     }
269:   }
270:   PetscFunctionReturn(PETSC_SUCCESS);
271: }

273: // ==========================================================================================
274: // VecSeq_CUPM - Public API - Constructors
275: // ==========================================================================================

277: // VecCreateSeqCUPM()
278: template <device::cupm::DeviceType T>
279: inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
280: {
281:   PetscFunctionBegin;
282:   PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
283:   PetscFunctionReturn(PETSC_SUCCESS);
284: }

286: // VecCreateSeqCUPMWithArrays()
287: template <device::cupm::DeviceType T>
288: inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
289: {
290:   PetscDeviceContext dctx;

292:   PetscFunctionBegin;
293:   PetscCall(GetHandles_(&dctx));
294:   // do NOT call VecSetType(), otherwise ops->create() -> create() ->
295:   // CreateSeqCUPM_() is called!
296:   PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
297:   PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
298:   PetscFunctionReturn(PETSC_SUCCESS);
299: }

301: // v->ops->duplicate
302: template <device::cupm::DeviceType T>
303: inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
304: {
305:   PetscDeviceContext dctx;

307:   PetscFunctionBegin;
308:   PetscCall(GetHandles_(&dctx));
309:   PetscCall(Duplicate_CUPMBase(v, y, dctx));
310:   PetscFunctionReturn(PETSC_SUCCESS);
311: }

313: // ==========================================================================================
314: // VecSeq_CUPM - Public API - Utility
315: // ==========================================================================================

317: // v->ops->bindtocpu
318: template <device::cupm::DeviceType T>
319: inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
320: {
321:   PetscDeviceContext dctx;

323:   PetscFunctionBegin;
324:   PetscCall(GetHandles_(&dctx));
325:   PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

327:   // REVIEW ME: this absolutely should be some sort of bulk mempcy rather than this mess
328:   VecSetOp_CUPM(dot, VecDot_Seq, Dot);
329:   VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
330:   VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
331:   VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
332:   VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
333:   VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
334:   v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
335:   VecSetOp_CUPM(max, VecMax_Seq, Max);
336:   VecSetOp_CUPM(min, VecMin_Seq, Min);
337:   VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
338:   VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
339:   PetscFunctionReturn(PETSC_SUCCESS);
340: }

342: // ==========================================================================================
343: // VecSeq_CUPM - Public API - Mutators
344: // ==========================================================================================

346: // v->ops->getlocalvector or v->ops->getlocalvectorread
347: template <device::cupm::DeviceType T>
348: template <PetscMemoryAccessMode access>
349: inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
350: {
351:   PetscBool wisseqcupm;

353:   PetscFunctionBegin;
354:   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
355:   PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
356:   if (wisseqcupm) {
357:     if (const auto wseq = VecIMPLCast(w)) {
358:       if (auto &alloced = wseq->array_allocated) {
359:         const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

361:         PetscCall(PetscFree(alloced));
362:       }
363:       wseq->array         = nullptr;
364:       wseq->unplacedarray = nullptr;
365:     }
366:     if (const auto wcu = VecCUPMCast(w)) {
367:       if (auto &device_array = wcu->array_d) {
368:         cupmStream_t stream;

370:         PetscCall(GetHandles_(&stream));
371:         PetscCallCUPM(cupmFreeAsync(device_array, stream));
372:       }
373:       PetscCall(PetscFree(w->spptr /* wcu */));
374:     }
375:   }
376:   if (v->petscnative && wisseqcupm) {
377:     PetscCall(PetscFree(w->data));
378:     w->data          = v->data;
379:     w->offloadmask   = v->offloadmask;
380:     w->pinned_memory = v->pinned_memory;
381:     w->spptr         = v->spptr;
382:     PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
383:   } else {
384:     const auto array = &VecIMPLCast(w)->array;

386:     if (access == PETSC_MEMORY_ACCESS_READ) {
387:       PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
388:     } else {
389:       PetscCall(VecGetArray(v, array));
390:     }
391:     w->offloadmask = PETSC_OFFLOAD_CPU;
392:     if (wisseqcupm) {
393:       PetscDeviceContext dctx;

395:       PetscCall(GetHandles_(&dctx));
396:       PetscCall(DeviceAllocateCheck_(dctx, w));
397:     }
398:   }
399:   PetscFunctionReturn(PETSC_SUCCESS);
400: }

402: // v->ops->restorelocalvector or v->ops->restorelocalvectorread
403: template <device::cupm::DeviceType T>
404: template <PetscMemoryAccessMode access>
405: inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
406: {
407:   PetscBool wisseqcupm;

409:   PetscFunctionBegin;
410:   PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
411:   PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
412:   if (v->petscnative && wisseqcupm) {
413:     // the assignments to nullptr are __critical__, as w may persist after this call returns
414:     // and shouldn't share data with v!
415:     v->pinned_memory = w->pinned_memory;
416:     v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
417:     v->data          = util::exchange(w->data, nullptr);
418:     v->spptr         = util::exchange(w->spptr, nullptr);
419:   } else {
420:     const auto array = &VecIMPLCast(w)->array;

422:     if (access == PETSC_MEMORY_ACCESS_READ) {
423:       PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
424:     } else {
425:       PetscCall(VecRestoreArray(v, array));
426:     }
427:     if (w->spptr && wisseqcupm) {
428:       cupmStream_t stream;

430:       PetscCall(GetHandles_(&stream));
431:       PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
432:       PetscCall(PetscFree(w->spptr));
433:     }
434:   }
435:   PetscFunctionReturn(PETSC_SUCCESS);
436: }

438: // ==========================================================================================
439: // VecSeq_CUPM - Public API - Compute Methods
440: // ==========================================================================================

442: // VecAYPXAsync_Private
443: template <device::cupm::DeviceType T>
444: inline PetscErrorCode VecSeq_CUPM<T>::AYPXAsync(Vec yin, PetscScalar alpha, Vec xin, PetscDeviceContext dctx) noexcept
445: {
446:   const auto n = static_cast<cupmBlasInt_t>(yin->map->n);
447:   PetscBool  xiscupm;

449:   PetscFunctionBegin;
450:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
451:   if (!xiscupm) {
452:     PetscCall(VecAYPX_Seq(yin, alpha, xin));
453:     PetscFunctionReturn(PETSC_SUCCESS);
454:   }
455:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
456:   if (alpha == PetscScalar(0.0)) {
457:     cupmStream_t stream;

459:     PetscCall(GetHandlesFrom_(dctx, &stream));
460:     PetscCall(PetscLogGpuTimeBegin());
461:     PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
462:     PetscCall(PetscLogGpuTimeEnd());
463:   } else if (n) {
464:     const auto       alphaIsOne = alpha == PetscScalar(1.0);
465:     const auto       calpha     = cupmScalarPtrCast(&alpha);
466:     cupmBlasHandle_t cupmBlasHandle;

468:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
469:     {
470:       const auto yptr = DeviceArrayReadWrite(dctx, yin);
471:       const auto xptr = DeviceArrayRead(dctx, xin);

473:       PetscCall(PetscLogGpuTimeBegin());
474:       if (alphaIsOne) {
475:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
476:       } else {
477:         const auto one = cupmScalarCast(1.0);

479:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
480:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
481:       }
482:       PetscCall(PetscLogGpuTimeEnd());
483:     }
484:     PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
485:   }
486:   if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
487:   PetscFunctionReturn(PETSC_SUCCESS);
488: }

490: // v->ops->aypx
491: template <device::cupm::DeviceType T>
492: inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
493: {
494:   PetscFunctionBegin;
495:   PetscCall(AYPXAsync(yin, alpha, xin, nullptr));
496:   PetscFunctionReturn(PETSC_SUCCESS);
497: }

499: // VecAXPYAsync_Private
500: template <device::cupm::DeviceType T>
501: inline PetscErrorCode VecSeq_CUPM<T>::AXPYAsync(Vec yin, PetscScalar alpha, Vec xin, PetscDeviceContext dctx) noexcept
502: {
503:   PetscBool xiscupm;

505:   PetscFunctionBegin;
506:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
507:   if (xiscupm) {
508:     const auto       n = static_cast<cupmBlasInt_t>(yin->map->n);
509:     cupmBlasHandle_t cupmBlasHandle;

511:     PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
512:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
513:     PetscCall(PetscLogGpuTimeBegin());
514:     PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
515:     PetscCall(PetscLogGpuTimeEnd());
516:     PetscCall(PetscLogGpuFlops(2 * n));
517:     if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
518:   } else {
519:     PetscCall(VecAXPY_Seq(yin, alpha, xin));
520:   }
521:   PetscFunctionReturn(PETSC_SUCCESS);
522: }

524: // v->ops->axpy
525: template <device::cupm::DeviceType T>
526: inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
527: {
528:   PetscFunctionBegin;
529:   PetscCall(AXPYAsync(yin, alpha, xin, nullptr));
530:   PetscFunctionReturn(PETSC_SUCCESS);
531: }

533: namespace detail
534: {

536: struct divides {
537:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return rhs == PetscScalar{0.0} ? rhs : lhs / rhs; }
538: };

540: } // namespace detail

542: // VecPointwiseDivideAsync_Private
543: template <device::cupm::DeviceType T>
544: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivideAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
545: {
546:   PetscFunctionBegin;
547:   PetscCall(PointwiseBinaryDispatch_(VecPointwiseDivide_Seq, detail::divides{}, wout, xin, yin, dctx));
548:   PetscFunctionReturn(PETSC_SUCCESS);
549: }

551: // v->ops->pointwisedivide
552: template <device::cupm::DeviceType T>
553: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec wout, Vec xin, Vec yin) noexcept
554: {
555:   PetscFunctionBegin;
556:   PetscCall(PointwiseDivideAsync(wout, xin, yin, nullptr));
557:   PetscFunctionReturn(PETSC_SUCCESS);
558: }

560: // VecPointwiseMultAsync_Private
561: template <device::cupm::DeviceType T>
562: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMultAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
563: {
564:   PetscFunctionBegin;
565:   PetscCall(PointwiseBinaryDispatch_(VecPointwiseMult_Seq, thrust::multiplies<PetscScalar>{}, wout, xin, yin, dctx));
566:   PetscFunctionReturn(PETSC_SUCCESS);
567: }

569: // v->ops->pointwisemult
570: template <device::cupm::DeviceType T>
571: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec wout, Vec xin, Vec yin) noexcept
572: {
573:   PetscFunctionBegin;
574:   PetscCall(PointwiseMultAsync(wout, xin, yin, nullptr));
575:   PetscFunctionReturn(PETSC_SUCCESS);
576: }

578: namespace detail
579: {

581: struct MaximumRealPart {
582:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::maximum<PetscReal>{}(PetscRealPart(lhs), PetscRealPart(rhs)); }
583: };

585: } // namespace detail

587: // VecPointwiseMaxAsync_Private
588: template <device::cupm::DeviceType T>
589: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
590: {
591:   PetscFunctionBegin;
592:   PetscCall(PointwiseBinaryDispatch_(VecPointwiseMax_Seq, detail::MaximumRealPart{}, wout, xin, yin, dctx));
593:   PetscFunctionReturn(PETSC_SUCCESS);
594: }

596: // v->ops->pointwisemax
597: template <device::cupm::DeviceType T>
598: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMax(Vec wout, Vec xin, Vec yin) noexcept
599: {
600:   PetscFunctionBegin;
601:   PetscCall(PointwiseMaxAsync(wout, xin, yin, nullptr));
602:   PetscFunctionReturn(PETSC_SUCCESS);
603: }

605: namespace detail
606: {

608: struct MaximumAbsoluteValue {
609:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::maximum<PetscReal>{}(PetscAbsScalar(lhs), PetscAbsScalar(rhs)); }
610: };

612: } // namespace detail

614: // VecPointwiseMaxAbsAsync_Private
615: template <device::cupm::DeviceType T>
616: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAbsAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
617: {
618:   PetscFunctionBegin;
619:   PetscCall(PointwiseBinaryDispatch_(VecPointwiseMaxAbs_Seq, detail::MaximumAbsoluteValue{}, wout, xin, yin, dctx));
620:   PetscFunctionReturn(PETSC_SUCCESS);
621: }

623: // v->ops->pointwisemaxabs
624: template <device::cupm::DeviceType T>
625: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAbs(Vec wout, Vec xin, Vec yin) noexcept
626: {
627:   PetscFunctionBegin;
628:   PetscCall(PointwiseMaxAbsAsync(wout, xin, yin, nullptr));
629:   PetscFunctionReturn(PETSC_SUCCESS);
630: }

632: namespace detail
633: {

635: struct MinimumRealPart {
636:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::minimum<PetscReal>{}(PetscRealPart(lhs), PetscRealPart(rhs)); }
637: };

639: } // namespace detail

641: // VecPointwiseMinAsync_Private
642: template <device::cupm::DeviceType T>
643: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMinAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
644: {
645:   PetscFunctionBegin;
646:   PetscCall(PointwiseBinaryDispatch_(VecPointwiseMin_Seq, detail::MinimumRealPart{}, wout, xin, yin, dctx));
647:   PetscFunctionReturn(PETSC_SUCCESS);
648: }

650: // v->ops->pointwisemin
651: template <device::cupm::DeviceType T>
652: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMin(Vec wout, Vec xin, Vec yin) noexcept
653: {
654:   PetscFunctionBegin;
655:   PetscCall(PointwiseMinAsync(wout, xin, yin, nullptr));
656:   PetscFunctionReturn(PETSC_SUCCESS);
657: }

659: namespace detail
660: {

662: struct Reciprocal {
663:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept
664:   {
665:     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
666:     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
667:     // everything in PetscScalar...
668:     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
669:   }
670: };

672: } // namespace detail

674: // VecReciprocalAsync_Private
675: template <device::cupm::DeviceType T>
676: inline PetscErrorCode VecSeq_CUPM<T>::ReciprocalAsync(Vec xin, PetscDeviceContext dctx) noexcept
677: {
678:   PetscFunctionBegin;
679:   PetscCall(PointwiseUnary_(detail::Reciprocal{}, xin, nullptr, dctx));
680:   PetscFunctionReturn(PETSC_SUCCESS);
681: }

683: // v->ops->reciprocal
684: template <device::cupm::DeviceType T>
685: inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
686: {
687:   PetscFunctionBegin;
688:   PetscCall(ReciprocalAsync(xin, nullptr));
689:   PetscFunctionReturn(PETSC_SUCCESS);
690: }

692: namespace detail
693: {

695: struct AbsoluteValue {
696:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscAbsScalar(s); }
697: };

699: } // namespace detail

701: // VecAbsAsync_Private
702: template <device::cupm::DeviceType T>
703: inline PetscErrorCode VecSeq_CUPM<T>::AbsAsync(Vec xin, PetscDeviceContext dctx) noexcept
704: {
705:   PetscFunctionBegin;
706:   PetscCall(PointwiseUnary_(detail::AbsoluteValue{}, xin, nullptr, dctx));
707:   PetscFunctionReturn(PETSC_SUCCESS);
708: }

710: // v->ops->abs
711: template <device::cupm::DeviceType T>
712: inline PetscErrorCode VecSeq_CUPM<T>::Abs(Vec xin) noexcept
713: {
714:   PetscFunctionBegin;
715:   PetscCall(AbsAsync(xin, nullptr));
716:   PetscFunctionReturn(PETSC_SUCCESS);
717: }

719: namespace detail
720: {

722: struct SignZeroToSignedUnit {
723:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return VecSignZeroToSignedUnit_Private(PetscRealPart(s)); }
724: };

726: struct SignZeroToZero {
727:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return VecSignZeroToZero_Private(PetscRealPart(s)); }
728: };

730: struct SignZeroToSignedZero {
731:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return VecSignZeroToSignedZero_Private(PetscRealPart(s)); }
732: };

734: } // namespace detail

736: // VecPointwiseSignAsync_Private
737: template <device::cupm::DeviceType T>
738: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseSignAsync(Vec yout, Vec xin, VecSignMode sign_type, PetscDeviceContext dctx) noexcept
739: {
740:   PetscFunctionBegin;
741:   switch (sign_type) {
742:   case VEC_SIGN_ZERO_TO_ZERO:
743:     PetscCall(PointwiseUnary_(detail::SignZeroToZero{}, xin, yout, dctx));
744:     break;
745:   case VEC_SIGN_ZERO_TO_SIGNED_ZERO:
746:     PetscCall(PointwiseUnary_(detail::SignZeroToSignedZero{}, xin, yout, dctx));
747:     break;
748:   case VEC_SIGN_ZERO_TO_SIGNED_UNIT:
749:     PetscCall(PointwiseUnary_(detail::SignZeroToSignedUnit{}, xin, yout, dctx));
750:     break;
751:   }
752:   PetscFunctionReturn(PETSC_SUCCESS);
753: }

755: namespace detail
756: {

758: struct SquareRootAbsoluteValue {
759:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscSqrtReal(PetscAbsScalar(s)); }
760: };

762: } // namespace detail

764: // VecSqrtAbsAsync_Private
765: template <device::cupm::DeviceType T>
766: inline PetscErrorCode VecSeq_CUPM<T>::SqrtAbsAsync(Vec xin, PetscDeviceContext dctx) noexcept
767: {
768:   PetscFunctionBegin;
769:   PetscCall(PointwiseUnary_(detail::SquareRootAbsoluteValue{}, xin, nullptr, dctx));
770:   PetscFunctionReturn(PETSC_SUCCESS);
771: }

773: // v->ops->sqrt
774: template <device::cupm::DeviceType T>
775: inline PetscErrorCode VecSeq_CUPM<T>::SqrtAbs(Vec xin) noexcept
776: {
777:   PetscFunctionBegin;
778:   PetscCall(SqrtAbsAsync(xin, nullptr));
779:   PetscFunctionReturn(PETSC_SUCCESS);
780: }

782: namespace detail
783: {

785: struct Exponent {
786:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscExpScalar(s); }
787: };

789: } // namespace detail

791: // VecExpAsync_Private
792: template <device::cupm::DeviceType T>
793: inline PetscErrorCode VecSeq_CUPM<T>::ExpAsync(Vec xin, PetscDeviceContext dctx) noexcept
794: {
795:   PetscFunctionBegin;
796:   PetscCall(PointwiseUnary_(detail::Exponent{}, xin, nullptr, dctx));
797:   PetscFunctionReturn(PETSC_SUCCESS);
798: }

800: // v->ops->exp
801: template <device::cupm::DeviceType T>
802: inline PetscErrorCode VecSeq_CUPM<T>::Exp(Vec xin) noexcept
803: {
804:   PetscFunctionBegin;
805:   PetscCall(ExpAsync(xin, nullptr));
806:   PetscFunctionReturn(PETSC_SUCCESS);
807: }

809: namespace detail
810: {

812: struct Logarithm {
813:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscLogScalar(s); }
814: };

816: } // namespace detail

818: // VecLogAsync_Private
819: template <device::cupm::DeviceType T>
820: inline PetscErrorCode VecSeq_CUPM<T>::LogAsync(Vec xin, PetscDeviceContext dctx) noexcept
821: {
822:   PetscFunctionBegin;
823:   PetscCall(PointwiseUnary_(detail::Logarithm{}, xin, nullptr, dctx));
824:   PetscFunctionReturn(PETSC_SUCCESS);
825: }

827: // v->ops->log
828: template <device::cupm::DeviceType T>
829: inline PetscErrorCode VecSeq_CUPM<T>::Log(Vec xin) noexcept
830: {
831:   PetscFunctionBegin;
832:   PetscCall(LogAsync(xin, nullptr));
833:   PetscFunctionReturn(PETSC_SUCCESS);
834: }

836: // v->ops->waxpy
837: template <device::cupm::DeviceType T>
838: inline PetscErrorCode VecSeq_CUPM<T>::WAXPYAsync(Vec win, PetscScalar alpha, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
839: {
840:   PetscBool xiscupm, yiscupm;

842:   PetscFunctionBegin;
843:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
844:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
845:   if (!xiscupm || !yiscupm) {
846:     PetscCall(VecWAXPY_Seq(win, alpha, xin, yin));
847:     PetscFunctionReturn(PETSC_SUCCESS);
848:   }
849:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
850:   if (alpha == PetscScalar(0.0)) {
851:     PetscCall(CopyAsync(yin, win, dctx));
852:   } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
853:     cupmBlasHandle_t cupmBlasHandle;
854:     cupmStream_t     stream;
855:     PetscBool        xiscupm, yiscupm;

857:     PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
858:     PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
859:     if (!xiscupm || !yiscupm) {
860:       PetscCall(VecWAXPY_Seq(win, alpha, xin, yin));
861:       PetscFunctionReturn(PETSC_SUCCESS);
862:     }
863:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle, NULL, &stream));
864:     {
865:       const auto wptr = DeviceArrayWrite(dctx, win);

867:       PetscCall(PetscLogGpuTimeBegin());
868:       PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
869:       PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
870:       PetscCall(PetscLogGpuTimeEnd());
871:     }
872:     PetscCall(PetscLogGpuFlops(2 * n));
873:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
874:   }
875:   PetscFunctionReturn(PETSC_SUCCESS);
876: }

878: // v->ops->waxpy
879: template <device::cupm::DeviceType T>
880: inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
881: {
882:   PetscFunctionBegin;
883:   PetscCall(WAXPYAsync(win, alpha, xin, yin, nullptr));
884:   PetscFunctionReturn(PETSC_SUCCESS);
885: }

namespace kernels
{

// Device kernel. For every index i in [0, size) handed to this thread by util::grid_stride_1D(),
//   xptr[i] <- xptr[i] + aptr[0] * yptr[0][i] + ... + aptr[N - 1] * yptr[N - 1][i]
// with N = sizeof...(Args). Each Args is `const PetscScalar *` (see MAXPY_kernel_dispatch_).
// aptr is dereferenced on the device, so it must point to device-accessible memory. Its N
// coefficients are first staged in shared memory.
// NOTE(review): the staging step only loads through threads 0..N-1, so it assumes blockDim.x >= N
// -- confirm against the launch configuration chosen by PetscCUPMLaunchKernel1D().
template <typename... Args>
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  // every thread of the block must see the staged coefficients before using them
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
  // these may look the same but give different results!
  // (floating-point addition is not associative: the enabled branch accumulates each term onto
  // xptr[i] in turn, the disabled one adds a separately accumulated sum to xptr[i] at the end)
#if 0
    PetscScalar sum = 0.0;

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
#else
    auto sum = xptr[i];

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j] * yptr_p[j][i];
    xptr[i] = sum;
#endif
  });
  return;
}

} // namespace kernels

namespace detail
{

// Maps any (type, index) pair to just the type. Inside a pack expansion over std::size_t... Idx,
//   typename repeat_type<X, Idx>::type...
// therefore produces X once per index, i.e. X repeated sizeof...(Idx) times. The index value itself
// is never used; it only drives the expansion.
template <typename Repeated, std::size_t /* ignored index */>
struct repeat_type {
  using type = Repeated;
};

} // namespace detail

// Launches one MAXPY_kernel over `size` entries. The launch adds aptr[k] * yin[k] into xptr for
// each of the sizeof...(Idx) vectors yin[0], ..., yin[sizeof...(Idx) - 1].
//  - The kernel gets one `const PetscScalar *` parameter per index, generated via
//    detail::repeat_type.
//  - Each y vector's device array comes from DeviceArrayRead() on dctx. Those wrapper temporaries
//    live until the end of the PetscCall() statement, so they outlive the launch call.
//  - aptr is read inside the kernel, so it must be a device-accessible pointer.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

956: template <device::cupm::DeviceType T>
957: template <int N>
958: inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
959: {
960:   PetscFunctionBegin;
961:   PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
962:   yidx += N;
963:   PetscFunctionReturn(PETSC_SUCCESS);
964: }

966: // VecMAXPYAsync_Private
967: template <device::cupm::DeviceType T>
968: inline PetscErrorCode VecSeq_CUPM<T>::MAXPYAsync(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin, PetscDeviceContext dctx) noexcept
969: {
970:   const auto   n = xin->map->n;
971:   cupmStream_t stream;
972:   PetscBool    yiscupm = PETSC_TRUE;

974:   PetscFunctionBegin;
975:   for (PetscInt i = 0; i < nv && yiscupm; i++) PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin[i]), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
976:   if (!yiscupm) {
977:     PetscCall(VecMAXPY_Seq(xin, nv, alpha, yin));
978:     PetscFunctionReturn(PETSC_SUCCESS);
979:   }
980:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
981:   PetscCall(GetHandlesFrom_(dctx, &stream));
982:   {
983:     const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
984:     PetscScalar *d_alpha = nullptr;
985:     PetscInt     yidx    = 0;

987:     // placement of early-return is deliberate, we would like to capture the
988:     // DeviceArrayReadWrite() call (which calls PetscObjectStateIncreate()) before we bail
989:     if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
990:     PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
991:     PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
992:     PetscCall(PetscLogGpuTimeBegin());
993:     do {
994:       switch (nv - yidx) {
995:       case 7:
996:         PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
997:         break;
998:       case 6:
999:         PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
1000:         break;
1001:       case 5:
1002:         PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
1003:         break;
1004:       case 4:
1005:         PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
1006:         break;
1007:       case 3:
1008:         PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
1009:         break;
1010:       case 2:
1011:         PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
1012:         break;
1013:       case 1:
1014:         PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
1015:         break;
1016:       default: // 8 or more
1017:         PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
1018:         break;
1019:       }
1020:     } while (yidx < nv);
1021:     PetscCall(PetscLogGpuTimeEnd());
1022:     PetscCall(PetscDeviceFree(dctx, d_alpha));
1023:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1024:   }
1025:   PetscCall(PetscLogGpuFlops(nv * 2 * n));
1026:   PetscFunctionReturn(PETSC_SUCCESS);
1027: }

1029: // v->ops->maxpy
1030: template <device::cupm::DeviceType T>
1031: inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
1032: {
1033:   PetscFunctionBegin;
1034:   PetscCall(MAXPYAsync(xin, nv, alpha, yin, nullptr));
1035:   PetscFunctionReturn(PETSC_SUCCESS);
1036: }

1038: template <device::cupm::DeviceType T>
1039: inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
1040: {
1041:   PetscBool yiscupm;

1043:   PetscFunctionBegin;
1044:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1045:   if (!yiscupm) {
1046:     PetscCall(VecDot_Seq(xin, yin, z));
1047:     PetscFunctionReturn(PETSC_SUCCESS);
1048:   }
1049:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1050:     PetscDeviceContext dctx;
1051:     cupmBlasHandle_t   cupmBlasHandle;

1053:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1054:     // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
1055:     // second
1056:     PetscCall(PetscLogGpuTimeBegin());
1057:     PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
1058:     PetscCall(PetscLogGpuTimeEnd());
1059:     PetscCall(PetscLogGpuFlops(2 * n - 1));
1060:   } else {
1061:     *z = 0.0;
1062:   }
1063:   PetscFunctionReturn(PETSC_SUCCESS);
1064: }

// Launch geometry of MDot_kernel (see MDot_kernel_dispatch_): MDOT_WORKGROUP_NUM blocks of
// MDOT_WORKGROUP_SIZE threads.
//  - MDOT_WORKGROUP_SIZE also sizes MDot_kernel's shared-memory reduction buffer.
//  - Each vector of a launch gets MDOT_WORKGROUP_NUM per-block partial sums in the scratch
//    results buffer.
#define MDOT_WORKGROUP_NUM  128
#define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM

1069: namespace kernels
1070: {

1072: PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
1073: {
1074:   const auto group_entries = (size - 1) / gridDim.x + 1;
1075:   // for very small vectors, a group should still do some work
1076:   return group_entries ? group_entries : 1;
1077: }

1079: template <typename... ConstPetscScalarPointer>
1080: PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
1081: {
1082:   constexpr int      N        = sizeof...(ConstPetscScalarPointer);
1083:   const PetscScalar *ylocal[] = {y...};
1084:   PetscScalar        sumlocal[N];

1086:   PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

1088:   // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
1089:   // types, so each of these go on separate lines...
1090:   const auto tx       = threadIdx.x;
1091:   const auto bx       = blockIdx.x;
1092:   const auto bdx      = blockDim.x;
1093:   const auto gdx      = gridDim.x;
1094:   const auto worksize = EntriesPerGroup(size);
1095:   const auto begin    = tx + bx * worksize;
1096:   const auto end      = min((bx + 1) * worksize, size);

1098: #pragma unroll
1099:   for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

1101:   for (auto i = begin; i < end; i += bdx) {
1102:     const auto xi = x[i]; // load only once from global memory!

1104: #pragma unroll
1105:     for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
1106:   }

1108: #pragma unroll
1109:   for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

1111:   // parallel reduction
1112:   for (auto stride = bdx / 2; stride > 0; stride /= 2) {
1113:     __syncthreads();
1114:     if (tx < stride) {
1115: #pragma unroll
1116:       for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
1117:     }
1118:   }
1119:   // bottom N threads per block write to global memory
1120:   // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
1121:   // writes to the same sections in the above loop that it is about to read from below, but
1122:   // running this under the racecheck tool of compute-sanitizer reports a write-after-write hazard.
1123:   __syncthreads();
1124:   if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
1125:   return;
1126: }

1128: namespace
1129: {

1131: PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
1132: {
1133:   int         local_i = 0;
1134:   PetscScalar local_results[8];

1136:   // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
1137:   //
1138:   // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
1139:   // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
1140:   // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
1141:   //  |  ______________________________________________________/
1142:   //  | /            <- MDOT_WORKGROUP_NUM ->
1143:   //  |/
1144:   //  +
1145:   //  v
1146:   // *-*-*
1147:   // | | | ...
1148:   // *-*-*
1149:   //
1150:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
1151:     PetscScalar z_sum = 0;

1153:     for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
1154:     local_results[local_i++] = z_sum;
1155:   });
1156:   // if we needed more than 1 workgroup to handle the vector we should sync since other threads
1157:   // may currently be reading from results
1158:   if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
1159:   // Local buffer is now written to global memory
1160:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
1161:     const auto j = --local_i;

1163:     if (j >= 0) results[i] = local_results[j];
1164:   });
1165:   return;
1166: }

1168: } // namespace

1170: #if PetscDefined(USING_HCC)
1171: namespace do_not_use
1172: {

1174: inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
1175: {
1176:   (void)sum_kernel;
1177: }

1179: } // namespace do_not_use
1180: #endif

1182: } // namespace kernels

// Launches one MDot_kernel computing per-block partial dot products of xarr with the
// sizeof...(Idx) vectors yin[0], ..., yin[sizeof...(Idx) - 1].
//  - The grid is MDOT_WORKGROUP_NUM blocks of MDOT_WORKGROUP_SIZE threads, and the launch writes
//    MDOT_WORKGROUP_NUM partial sums per vector into results.
//  - The kernel gets one `const PetscScalar *` parameter per index, generated via
//    detail::repeat_type.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

1203: template <device::cupm::DeviceType T>
1204: template <int N>
1205: inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
1206: {
1207:   PetscFunctionBegin;
1208:   PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
1209:   yidx += N;
1210:   PetscFunctionReturn(PETSC_SUCCESS);
1211: }

1213: template <device::cupm::DeviceType T>
1214: inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
1215: {
1216:   // the largest possible size of a batch
1217:   constexpr PetscInt batchsize = 8;
1218:   // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
1219:   // do not create substreams. Note we don't create more than 8 streams, in practice we could
1220:   // not get more parallelism with higher numbers.
1221:   const auto   num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
1222:   const auto   n               = xin->map->n;
1223:   const auto   nwork           = nv * MDOT_WORKGROUP_NUM;
1224:   PetscScalar *d_results;
1225:   cupmStream_t stream;

1227:   PetscFunctionBegin;
1228:   PetscCall(GetHandlesFrom_(dctx, &stream));
1229:   // allocate scratchpad memory for the results of individual work groups
1230:   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
1231:   {
1232:     const auto          xptr       = DeviceArrayRead(dctx, xin);
1233:     PetscInt            yidx       = 0;
1234:     auto                subidx     = 0;
1235:     auto                cur_stream = stream;
1236:     auto                cur_ctx    = dctx;
1237:     PetscDeviceContext *sub        = nullptr;
1238:     PetscStreamType     stype;

1240:     // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
1241:     // sub. Ideally the parent context should also join in on the fork, but it is extremely
1242:     // fiddly to do so presently
1243:     PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
1244:     if (stype == PETSC_STREAM_DEFAULT || stype == PETSC_STREAM_DEFAULT_WITH_BARRIER) stype = PETSC_STREAM_NONBLOCKING;
1245:     // If we have a default stream create nonblocking streams instead (as we can
1246:     // locally exploit the parallelism). Otherwise use the prescribed stream type.
1247:     PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
1248:     PetscCall(PetscLogGpuTimeBegin());
1249:     do {
1250:       if (num_sub_streams) {
1251:         cur_ctx = sub[subidx++ % num_sub_streams];
1252:         PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
1253:       }
1254:       // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
1255:       // it is very likely better to do 4+5 rather than 8+1
1256:       switch (nv - yidx) {
1257:       case 7:
1258:         PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1259:         break;
1260:       case 6:
1261:         PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1262:         break;
1263:       case 5:
1264:         PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1265:         break;
1266:       case 4:
1267:         PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1268:         break;
1269:       case 3:
1270:         PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1271:         break;
1272:       case 2:
1273:         PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1274:         break;
1275:       case 1:
1276:         PetscCall(MDot_kernel_dispatch_<1>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1277:         break;
1278:       default: // 8 or more
1279:         PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
1280:         break;
1281:       }
1282:     } while (yidx < nv);
1283:     PetscCall(PetscLogGpuTimeEnd());
1284:     PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
1285:   }

1287:   PetscCall(PetscCUPMLaunchKernel1D(nv, 0, stream, kernels::sum_kernel, nv, d_results));
1288:   // copy result of device reduction to host
1289:   PetscCall(PetscCUPMMemcpyAsync(z, d_results, nv, cupmMemcpyDeviceToHost, stream));
1290:   // do these now while final reduction is in flight
1291:   PetscCall(PetscLogGpuFlops(nwork));
1292:   PetscCall(PetscDeviceFree(dctx, d_results));
1293:   PetscFunctionReturn(PETSC_SUCCESS);
1294: }

// the MDOT_* launch parameters are only needed by the real-scalar MDot kernels and dispatch code above
#undef MDOT_WORKGROUP_NUM
#undef MDOT_WORKGROUP_SIZE

// Complex-scalar multi-dot, z[i] = dot(xin, yin[i]) for i in [0, nv), using one BLAS dot per vector.
//  - The dots are spread round-robin over min(nv, 8) contexts forked from dctx.
//  - Each BLAS handle is switched to device pointer mode so the result lands asynchronously in the
//    device scratch array d_z. The previous pointer mode is restored afterwards.
//  - After joining, d_z is copied to the host array z on dctx's stream. z is not synchronized
//    here; the caller is responsible.
// As in Dot(), yin[i] is passed first because BLAS conjugates its first argument.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  // NOTE(review): unlike the real-scalar overload, xin's device array is acquired before
  // PetscFunctionBegin and held until the function returns
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // result pointer d_z + i is device memory, so the handle must be in device pointer mode
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  // (no flops are logged here; the caller MDot() logs nv * (2n - 1) flops for the dot products)
  PetscFunctionReturn(PETSC_SUCCESS);
}

1334: // v->ops->mdot
1335: template <device::cupm::DeviceType T>
1336: inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
1337: {
1338:   PetscFunctionBegin;
1339:   if (PetscUnlikely(nv == 1)) {
1340:     // dot handles nv = 0 correctly
1341:     PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
1342:   } else if (const auto n = xin->map->n) {
1343:     PetscDeviceContext dctx;

1345:     PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
1346:     PetscCall(GetHandles_(&dctx));
1347:     PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
1348:     // REVIEW ME: double count of flops??
1349:     PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
1350:     PetscCall(PetscDeviceContextSynchronize(dctx));
1351:   } else {
1352:     PetscCall(PetscArrayzero(z, nv));
1353:   }
1354:   PetscFunctionReturn(PETSC_SUCCESS);
1355: }

1357: // VecSetAsync_Private
1358: template <device::cupm::DeviceType T>
1359: inline PetscErrorCode VecSeq_CUPM<T>::SetAsync(Vec xin, PetscScalar alpha, PetscDeviceContext dctx) noexcept
1360: {
1361:   const auto   n = xin->map->n;
1362:   cupmStream_t stream;

1364:   PetscFunctionBegin;
1365:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1366:   PetscCall(GetHandlesFrom_(dctx, &stream));
1367:   {
1368:     const auto xptr = DeviceArrayWrite(dctx, xin);

1370:     if (alpha == PetscScalar(0.0)) {
1371:       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1372:     } else {
1373:       const auto dptr = thrust::device_pointer_cast(xptr.data());

1375:       PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
1376:     }
1377:   }
1378:   if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1379:   PetscFunctionReturn(PETSC_SUCCESS);
1380: }

1382: // v->ops->set
1383: template <device::cupm::DeviceType T>
1384: inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
1385: {
1386:   PetscFunctionBegin;
1387:   PetscCall(SetAsync(xin, alpha, nullptr));
1388:   PetscFunctionReturn(PETSC_SUCCESS);
1389: }

1391: // VecScaleAsync_Private
1392: template <device::cupm::DeviceType T>
1393: inline PetscErrorCode VecSeq_CUPM<T>::ScaleAsync(Vec xin, PetscScalar alpha, PetscDeviceContext dctx) noexcept
1394: {
1395:   PetscFunctionBegin;
1396:   if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
1397:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1398:   if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1399:     PetscCall(SetAsync(xin, alpha, dctx));
1400:   } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1401:     cupmBlasHandle_t cupmBlasHandle;

1403:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1404:     PetscCall(PetscLogGpuTimeBegin());
1405:     PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1406:     PetscCall(PetscLogGpuTimeEnd());
1407:     PetscCall(PetscLogGpuFlops(n));
1408:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1409:   } else {
1410:     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1411:   }
1412:   PetscFunctionReturn(PETSC_SUCCESS);
1413: }

1415: // v->ops->scale
1416: template <device::cupm::DeviceType T>
1417: inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
1418: {
1419:   PetscFunctionBegin;
1420:   PetscCall(ScaleAsync(xin, alpha, nullptr));
1421:   PetscFunctionReturn(PETSC_SUCCESS);
1422: }

1424: // v->ops->tdot
1425: template <device::cupm::DeviceType T>
1426: inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
1427: {
1428:   PetscBool yiscupm;

1430:   PetscFunctionBegin;
1431:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1432:   if (!yiscupm) {
1433:     PetscCall(VecTDot_Seq(xin, yin, z));
1434:     PetscFunctionReturn(PETSC_SUCCESS);
1435:   }
1436:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1437:     PetscDeviceContext dctx;
1438:     cupmBlasHandle_t   cupmBlasHandle;

1440:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1441:     PetscCall(PetscLogGpuTimeBegin());
1442:     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1443:     PetscCall(PetscLogGpuTimeEnd());
1444:     PetscCall(PetscLogGpuFlops(2 * n - 1));
1445:   } else {
1446:     *z = 0.0;
1447:   }
1448:   PetscFunctionReturn(PETSC_SUCCESS);
1449: }

// VecCopyAsync_Private
// Copies the local entries of xin into yout (a no-op when they are the same Vec).
//  - The memcpy direction is derived from the offload masks: data is read from the device when
//    xin's mask reports device data (PetscOffloadDevice(xmask)).
//  - Data is written to the device when yout already has device data. It is also written there
//    when yout is a CUPM vector that is not bound to the CPU.
//  - An empty local xin only goes through MaybeIncrementEmptyLocalVec(yout).
// Errors on an unrecognized offload mask of yout or an unexpected memcpy kind.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CopyAsync(Vec xin, Vec yout, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto         mode = cupmMemcpyDeviceToDevice;
    cupmStream_t stream;

    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_CPU:
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm && !yout->boundtocpu) {
        /* If GPU vector, also ensure output is on GPU unless explicitly bound to CPU */
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      } else {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToHost : cupmMemcpyHostToHost;
      }
      break;
    }
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandlesFrom_(dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      const auto yptr = DeviceArrayWrite(dctx, yout);
      // NOTE(review): the DeviceArrayRead()/HostArrayRead() wrapper is a temporary destroyed at the
      // end of this declaration and only its raw pointer is kept; this relies on the underlying
      // array outliving the wrapper -- confirm
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      // same temporary-wrapper caveat as in the device-destination branch above
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      // the destination is yout's host array
      PetscCall(VecGetArrayWrite(yout, &yptr));
      // only the device-to-host transfer is accounted as GPU time
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

1519: // v->ops->copy
1520: template <device::cupm::DeviceType T>
1521: inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
1522: {
1523:   PetscFunctionBegin;
1524:   PetscCall(CopyAsync(xin, yout, nullptr));
1525:   PetscFunctionReturn(PETSC_SUCCESS);
1526: }

1528: // VecSwapAsync_Private
1529: template <device::cupm::DeviceType T>
1530: inline PetscErrorCode VecSeq_CUPM<T>::SwapAsync(Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
1531: {
1532:   PetscBool yiscupm;

1534:   PetscFunctionBegin;
1535:   if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
1536:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1537:   PetscCheck(yiscupm, PetscObjectComm(PetscObjectCast(yin)), PETSC_ERR_SUP, "Cannot swap with Y of type %s", PetscObjectCast(yin)->type_name);
1538:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1539:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1540:     cupmBlasHandle_t cupmBlasHandle;

1542:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1543:     PetscCall(PetscLogGpuTimeBegin());
1544:     PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1545:     PetscCall(PetscLogGpuTimeEnd());
1546:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1547:   } else {
1548:     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1549:     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1550:   }
1551:   PetscFunctionReturn(PETSC_SUCCESS);
1552: }

1554: // v->ops->swap
1555: template <device::cupm::DeviceType T>
1556: inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
1557: {
1558:   PetscFunctionBegin;
1559:   PetscCall(SwapAsync(xin, yin, nullptr));
1560:   PetscFunctionReturn(PETSC_SUCCESS);
1561: }

1563: // VecAXPYBYAsync_Private
1564: template <device::cupm::DeviceType T>
1565: inline PetscErrorCode VecSeq_CUPM<T>::AXPBYAsync(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin, PetscDeviceContext dctx) noexcept
1566: {
1567:   PetscBool xiscupm;

1569:   PetscFunctionBegin;
1570:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1571:   if (!xiscupm) {
1572:     PetscCall(VecAXPBY_Seq(yin, alpha, beta, xin));
1573:     PetscFunctionReturn(PETSC_SUCCESS);
1574:   }
1575:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1576:   if (alpha == PetscScalar(0.0)) {
1577:     PetscCall(ScaleAsync(yin, beta, dctx));
1578:   } else if (beta == PetscScalar(1.0)) {
1579:     PetscCall(AXPYAsync(yin, alpha, xin, dctx));
1580:   } else if (alpha == PetscScalar(1.0)) {
1581:     PetscCall(AYPXAsync(yin, beta, xin, dctx));
1582:   } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
1583:     PetscBool xiscupm;

1585:     PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1586:     if (!xiscupm) {
1587:       PetscCall(VecAXPBY_Seq(yin, alpha, beta, xin));
1588:       PetscFunctionReturn(PETSC_SUCCESS);
1589:     }

1591:     const auto       betaIsZero = beta == PetscScalar(0.0);
1592:     const auto       aptr       = cupmScalarPtrCast(&alpha);
1593:     cupmBlasHandle_t cupmBlasHandle;

1595:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1596:     {
1597:       const auto xptr = DeviceArrayRead(dctx, xin);

1599:       if (betaIsZero /* beta = 0 */) {
1600:         // here we can get away with purely write-only as we memcpy into it first
1601:         const auto   yptr = DeviceArrayWrite(dctx, yin);
1602:         cupmStream_t stream;

1604:         PetscCall(GetHandlesFrom_(dctx, &stream));
1605:         PetscCall(PetscLogGpuTimeBegin());
1606:         PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
1607:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
1608:       } else {
1609:         const auto yptr = DeviceArrayReadWrite(dctx, yin);

1611:         PetscCall(PetscLogGpuTimeBegin());
1612:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
1613:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
1614:       }
1615:     }
1616:     PetscCall(PetscLogGpuTimeEnd());
1617:     PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
1618:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1619:   } else {
1620:     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1621:   }
1622:   PetscFunctionReturn(PETSC_SUCCESS);
1623: }

1625: // v->ops->axpby
1626: template <device::cupm::DeviceType T>
1627: inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
1628: {
1629:   PetscFunctionBegin;
1630:   PetscCall(AXPBYAsync(yin, alpha, beta, xin, nullptr));
1631:   PetscFunctionReturn(PETSC_SUCCESS);
1632: }

1634: // VecAXPBYPCZAsync_Private
1635: template <device::cupm::DeviceType T>
1636: inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZAsync(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
1637: {
1638:   PetscFunctionBegin;
1639:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1640:   if (gamma != PetscScalar(1.0)) PetscCall(ScaleAsync(zin, gamma, dctx));
1641:   PetscCall(AXPYAsync(zin, alpha, xin, dctx));
1642:   PetscCall(AXPYAsync(zin, beta, yin, dctx));
1643:   PetscFunctionReturn(PETSC_SUCCESS);
1644: }

1646: // v->ops->axpbypcz
1647: template <device::cupm::DeviceType T>
1648: inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
1649: {
1650:   PetscFunctionBegin;
1651:   PetscCall(AXPBYPCZAsync(zin, alpha, beta, gamma, xin, yin, nullptr));
1652:   PetscFunctionReturn(PETSC_SUCCESS);
1653: }

// v->ops->norm
//
// Computes the requested norm of xin on the device with the BLAS handle of the default device
// context:
//   NORM_1                  -> z[0] = asum(x)
//   NORM_2, NORM_FROBENIUS  -> z[0] = nrm2(x)
//   NORM_1_AND_2            -> z[0] = asum(x), z[1] = nrm2(x)  (z must hold two values)
//   NORM_INFINITY           -> locate an entry with amax, copy it to the host, z[0] = |entry|
// A zero-length local vector stores 0 in every requested slot of z.
//
// NOTE(review): if the device BLAS follows reference-BLAS semantics, complex asum/amax measure
// |Re| + |Im| per entry rather than the modulus, so NORM_1 / NORM_INFINITY may differ from a
// modulus-based definition in complex builds — confirm this is intended.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0);
      if (type == NORM_1) break;
      // NORM_1_AND_2: advance so the 2-norm below is written to z[1]
      ++z; // fall-through
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      // the amax index is 1-based (hence the - 1); fetch that single entry to the host
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      // NOTE(review): xv is read right after the copy. This relies on PetscCUPMMemcpyAsync() having
      // finished device-to-host copies by the time it returns when not forced async (CopyAsync
      // passes an explicit "force async" flag to opt out of that) — confirm.
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty local vector: z[0] always, and z[1] as well for NORM_1_AND_2
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

1705: namespace detail
1706: {

1708: template <NormType wnormtype>
1709: class ErrorWNormTransformBase {
1710: public:
1711:   using result_type = thrust::tuple<PetscReal, PetscReal, PetscReal, PetscInt, PetscInt, PetscInt>;

1713:   constexpr explicit ErrorWNormTransformBase(PetscReal v) noexcept : ignore_max_{v} { }

1715: protected:
1716:   struct NormTuple {
1717:     PetscReal norm;
1718:     PetscInt  loc;
1719:   };

1721:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL static NormTuple compute_norm_(PetscReal err, PetscReal tol) noexcept
1722:   {
1723:     if (tol > 0.) {
1724:       const auto val = err / tol;

1726:       return {wnormtype == NORM_INFINITY ? val : PetscSqr(val), 1};
1727:     } else {
1728:       return {0.0, 0};
1729:     }
1730:   }

1732:   PetscReal ignore_max_;
1733: };

1735: template <NormType wnormtype>
1736: struct ErrorWNormTransform : ErrorWNormTransformBase<wnormtype> {
1737:   using base_type     = ErrorWNormTransformBase<wnormtype>;
1738:   using result_type   = typename base_type::result_type;
1739:   using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

1741:   using base_type::base_type;

1743:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
1744:   {
1745:     const auto u     = thrust::get<0>(x); // with x.get<0>(), cuda-12.4.0 gives error: class "cuda::std::__4::tuple" has no member "get"
1746:     const auto y     = thrust::get<1>(x);
1747:     const auto au    = PetscAbsScalar(u);
1748:     const auto ay    = PetscAbsScalar(y);
1749:     const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
1750:     const auto tola  = skip ? 0.0 : PetscRealPart(thrust::get<2>(x));
1751:     const auto tolr  = skip ? 0.0 : PetscRealPart(thrust::get<3>(x)) * PetscMax(au, ay);
1752:     const auto tol   = tola + tolr;
1753:     const auto err   = PetscAbsScalar(u - y);
1754:     const auto tup_a = this->compute_norm_(err, tola);
1755:     const auto tup_r = this->compute_norm_(err, tolr);
1756:     const auto tup_n = this->compute_norm_(err, tol);

1758:     return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
1759:   }
1760: };

1762: template <NormType wnormtype>
1763: struct ErrorWNormETransform : ErrorWNormTransformBase<wnormtype> {
1764:   using base_type     = ErrorWNormTransformBase<wnormtype>;
1765:   using result_type   = typename base_type::result_type;
1766:   using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

1768:   using base_type::base_type;

1770:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
1771:   {
1772:     const auto au    = PetscAbsScalar(thrust::get<0>(x));
1773:     const auto ay    = PetscAbsScalar(thrust::get<1>(x));
1774:     const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
1775:     const auto tola  = skip ? 0.0 : PetscRealPart(thrust::get<3>(x));
1776:     const auto tolr  = skip ? 0.0 : PetscRealPart(thrust::get<4>(x)) * PetscMax(au, ay);
1777:     const auto tol   = tola + tolr;
1778:     const auto err   = PetscAbsScalar(thrust::get<2>(x));
1779:     const auto tup_a = this->compute_norm_(err, tola);
1780:     const auto tup_r = this->compute_norm_(err, tolr);
1781:     const auto tup_n = this->compute_norm_(err, tol);

1783:     return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
1784:   }
1785: };

1787: template <NormType wnormtype>
1788: struct ErrorWNormReduce {
1789:   using value_type = typename ErrorWNormTransformBase<wnormtype>::result_type;

1791:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept
1792:   {
1793:     // cannot use lhs.get<0>() etc since the using decl above ambiguates the fact that
1794:     // result_type is a template, so in order to fix this we would need to write:
1795:     //
1796:     // lhs.template get<0>()
1797:     //
1798:     // which is unseemly.
1799:     if (wnormtype == NORM_INFINITY) {
1800:       // clang-format off
1801:       return {
1802:         PetscMax(thrust::get<0>(lhs), thrust::get<0>(rhs)),
1803:         PetscMax(thrust::get<1>(lhs), thrust::get<1>(rhs)),
1804:         PetscMax(thrust::get<2>(lhs), thrust::get<2>(rhs)),
1805:         thrust::get<3>(lhs) + thrust::get<3>(rhs),
1806:         thrust::get<4>(lhs) + thrust::get<4>(rhs),
1807:         thrust::get<5>(lhs) + thrust::get<5>(rhs)
1808:       };
1809:       // clang-format on
1810:     } else {
1811:       // clang-format off
1812:       return {
1813:         thrust::get<0>(lhs) + thrust::get<0>(rhs),
1814:         thrust::get<1>(lhs) + thrust::get<1>(rhs),
1815:         thrust::get<2>(lhs) + thrust::get<2>(rhs),
1816:         thrust::get<3>(lhs) + thrust::get<3>(rhs),
1817:         thrust::get<4>(lhs) + thrust::get<4>(rhs),
1818:         thrust::get<5>(lhs) + thrust::get<5>(rhs)
1819:       };
1820:       // clang-format on
1821:     }
1822:   }
1823: };

1825: template <template <NormType> class WNormTransformType, typename Tuple, typename cupmStream_t>
1826: inline PetscErrorCode ExecuteWNorm(Tuple &&first, Tuple &&last, NormType wnormtype, cupmStream_t stream, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
1827: {
1828:   auto      begin = thrust::make_zip_iterator(std::forward<Tuple>(first));
1829:   auto      end   = thrust::make_zip_iterator(std::forward<Tuple>(last));
1830:   PetscReal n = 0, na = 0, nr = 0;
1831:   PetscInt  n_loc = 0, na_loc = 0, nr_loc = 0;

1833:   PetscFunctionBegin;
1834:   // clang-format off
1835:   if (wnormtype == NORM_INFINITY) {
1836:     PetscCallThrust(
1837:       thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
1838:         thrust::transform_reduce,
1839:         stream,
1840:         std::move(begin),
1841:         std::move(end),
1842:         WNormTransformType<NORM_INFINITY>{ignore_max},
1843:         thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
1844:         ErrorWNormReduce<NORM_INFINITY>{}
1845:       )
1846:     );
1847:   } else {
1848:     PetscCallThrust(
1849:       thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
1850:         thrust::transform_reduce,
1851:         stream,
1852:         std::move(begin),
1853:         std::move(end),
1854:         WNormTransformType<NORM_2>{ignore_max},
1855:         thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
1856:         ErrorWNormReduce<NORM_2>{}
1857:       )
1858:     );
1859:   }
1860:   // clang-format on
1861:   if (wnormtype == NORM_2) {
1862:     *norm  = PetscSqrtReal(*norm);
1863:     *norma = PetscSqrtReal(*norma);
1864:     *normr = PetscSqrtReal(*normr);
1865:   }
1866:   PetscFunctionReturn(PETSC_SUCCESS);
1867: }

1869: } // namespace detail

1871: // v->ops->errorwnorm
1872: template <device::cupm::DeviceType T>
1873: inline PetscErrorCode VecSeq_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
1874: {
1875:   const auto         nl  = U->map->n;
1876:   auto               ait = thrust::make_constant_iterator(static_cast<PetscScalar>(atol));
1877:   auto               rit = thrust::make_constant_iterator(static_cast<PetscScalar>(rtol));
1878:   PetscDeviceContext dctx;
1879:   cupmStream_t       stream;

1881:   PetscFunctionBegin;
1882:   PetscCall(GetHandles_(&dctx, &stream));
1883:   {
1884:     const auto ConditionalDeviceArrayRead = [&](Vec v) {
1885:       if (v) {
1886:         return thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1887:       } else {
1888:         return thrust::device_ptr<PetscScalar>{nullptr};
1889:       }
1890:     };

1892:     const auto uarr = DeviceArrayRead(dctx, U);
1893:     const auto yarr = DeviceArrayRead(dctx, Y);
1894:     const auto uptr = thrust::device_pointer_cast(uarr.data());
1895:     const auto yptr = thrust::device_pointer_cast(yarr.data());
1896:     const auto eptr = ConditionalDeviceArrayRead(E);
1897:     const auto rptr = ConditionalDeviceArrayRead(vrtol);
1898:     const auto aptr = ConditionalDeviceArrayRead(vatol);

1900:     if (!vatol && !vrtol) {
1901:       if (E) {
1902:         // clang-format off
1903:         PetscCall(
1904:           detail::ExecuteWNorm<detail::ErrorWNormETransform>(
1905:             thrust::make_tuple(uptr, yptr, eptr, ait, rit),
1906:             thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rit),
1907:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1908:           )
1909:         );
1910:         // clang-format on
1911:       } else {
1912:         // clang-format off
1913:         PetscCall(
1914:           detail::ExecuteWNorm<detail::ErrorWNormTransform>(
1915:             thrust::make_tuple(uptr, yptr, ait, rit),
1916:             thrust::make_tuple(uptr + nl, yptr + nl, ait, rit),
1917:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1918:           )
1919:         );
1920:         // clang-format on
1921:       }
1922:     } else if (!vatol) {
1923:       if (E) {
1924:         // clang-format off
1925:         PetscCall(
1926:           detail::ExecuteWNorm<detail::ErrorWNormETransform>(
1927:             thrust::make_tuple(uptr, yptr, eptr, ait, rptr),
1928:             thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rptr + nl),
1929:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1930:           )
1931:         );
1932:         // clang-format on
1933:       } else {
1934:         // clang-format off
1935:         PetscCall(
1936:           detail::ExecuteWNorm<detail::ErrorWNormTransform>(
1937:             thrust::make_tuple(uptr, yptr, ait, rptr),
1938:             thrust::make_tuple(uptr + nl, yptr + nl, ait, rptr + nl),
1939:             wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1940:           )
1941:         );
1942:         // clang-format on
1943:       }
1944:     } else if (!vrtol) {
1945:       if (E) {
1946:         // clang-format off
1947:           PetscCall(
1948:             detail::ExecuteWNorm<detail::ErrorWNormETransform>(
1949:               thrust::make_tuple(uptr, yptr, eptr, aptr, rit),
1950:               thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rit),
1951:               wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1952:             )
1953:           );
1954:         // clang-format on
1955:       } else {
1956:         // clang-format off
1957:           PetscCall(
1958:             detail::ExecuteWNorm<detail::ErrorWNormTransform>(
1959:               thrust::make_tuple(uptr, yptr, aptr, rit),
1960:               thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rit),
1961:               wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1962:             )
1963:           );
1964:         // clang-format on
1965:       }
1966:     } else {
1967:       if (E) {
1968:         // clang-format off
1969:           PetscCall(
1970:             detail::ExecuteWNorm<detail::ErrorWNormETransform>(
1971:               thrust::make_tuple(uptr, yptr, eptr, aptr, rptr),
1972:               thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rptr + nl),
1973:               wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1974:             )
1975:           );
1976:         // clang-format on
1977:       } else {
1978:         // clang-format off
1979:           PetscCall(
1980:             detail::ExecuteWNorm<detail::ErrorWNormTransform>(
1981:               thrust::make_tuple(uptr, yptr, aptr, rptr),
1982:               thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rptr + nl),
1983:               wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
1984:             )
1985:           );
1986:         // clang-format on
1987:       }
1988:     }
1989:   }
1990:   PetscFunctionReturn(PETSC_SUCCESS);
1991: }

1993: namespace detail
1994: {
1995: struct dotnorm2_mult {
1996:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
1997:   {
1998:     const auto conjt = PetscConj(t);

2000:     return {s * conjt, t * conjt};
2001:   }
2002: };

2004: // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
2005: // would do it myself but now I am worried that they do so on purpose...
2006: struct dotnorm2_tuple_plus {
2007:   using value_type = thrust::tuple<PetscScalar, PetscScalar>;

2009:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {thrust::get<0>(lhs) + thrust::get<0>(rhs), thrust::get<1>(lhs) + thrust::get<1>(rhs)}; }
2010: };

2012: } // namespace detail

2014: // v->ops->dotnorm2
2015: template <device::cupm::DeviceType T>
2016: inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
2017: {
2018:   PetscDeviceContext dctx;
2019:   cupmStream_t       stream;

2021:   PetscFunctionBegin;
2022:   PetscCall(GetHandles_(&dctx, &stream));
2023:   {
2024:     PetscScalar dpt = 0.0, nmt = 0.0;
2025:     const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

2027:     // clang-format off
2028:     PetscCallThrust(
2029:       thrust::tie(*dp, *nm) = THRUST_CALL(
2030:         thrust::inner_product,
2031:         stream,
2032:         sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
2033:         thrust::make_tuple(dpt, nmt),
2034:         detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
2035:       );
2036:     );
2037:     // clang-format on
2038:   }
2039:   PetscFunctionReturn(PETSC_SUCCESS);
2040: }

2042: namespace detail
2043: {
2044: struct conjugate {
2045:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &x) const noexcept { return PetscConj(x); }
2046: };

2048: } // namespace detail

2050: // v->ops->conjugate
2051: template <device::cupm::DeviceType T>
2052: inline PetscErrorCode VecSeq_CUPM<T>::ConjugateAsync(Vec xin, PetscDeviceContext dctx) noexcept
2053: {
2054:   PetscFunctionBegin;
2055:   if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin, nullptr, dctx));
2056:   PetscFunctionReturn(PETSC_SUCCESS);
2057: }

2059: // v->ops->conjugate
2060: template <device::cupm::DeviceType T>
2061: inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
2062: {
2063:   PetscFunctionBegin;
2064:   PetscCall(ConjugateAsync(xin, nullptr));
2065:   PetscFunctionReturn(PETSC_SUCCESS);
2066: }

2068: namespace detail
2069: {

2071: struct real_part {
2072:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const noexcept { return {PetscRealPart(thrust::get<0>(x)), thrust::get<1>(x)}; }

2074:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(const PetscScalar &x) const noexcept { return PetscRealPart(x); }
2075: };

2077: // deriving from Operator allows us to "store" an instance of the operator in the class but
2078: // also take advantage of empty base class optimization if the operator is stateless
2079: template <typename Operator>
2080: class tuple_compare : Operator {
2081: public:
2082:   using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
2083:   using operator_type = Operator;

2085:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
2086:   {
2087:     if (op_()(thrust::get<0>(y), thrust::get<0>(x))) {
2088:       // if y is strictly greater/less than x, return y
2089:       return y;
2090:     } else if (thrust::get<0>(y) == thrust::get<0>(x)) {
2091:       // if equal, prefer lower index
2092:       return thrust::get<1>(y) < thrust::get<1>(x) ? y : x;
2093:     }
2094:     // otherwise return x
2095:     return x;
2096:   }

2098: private:
2099:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
2100: };

2102: } // namespace detail

// Shared implementation of Max() and Min().
//
// Checks that v is VECSEQCUPM/VECMPICUPM. It then reduces the entries of v on the device
// (their real parts in complex builds), seeding the reduction with the caller-provided *m.
// When p is non-null the reduction runs over (value, index) pairs with tuple_ftr and the winning
// index lands in *p. Otherwise only the value is reduced, with unary_ftr. *p is set to -1
// up front, so a zero-length local vector leaves *m untouched and *p == -1.
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // needed to:
    // 1. switch between transform_reduce and reduce
    // 2. strip the real_part functor from the arguments
    // (complex builds transform each entry with real_part before reducing; real builds reduce
    // the entries directly and drop the real_part argument)
#if PetscDefined(USE_COMPLEX)
  #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
#else
  #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
#endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // pair every entry with its index so the location can be reduced alongside the value
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
#undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}

2161: // v->ops->max
2162: template <device::cupm::DeviceType T>
2163: inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept
2164: {
2165: #if CCCL_VERSION >= 3001000
2166:   using tuple_functor = detail::tuple_compare<cuda::std::greater<PetscReal>>;
2167:   using unary_functor = cuda::maximum<PetscReal>;
2168: #else
2169:   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
2170:   using unary_functor = thrust::maximum<PetscReal>;
2171: #endif

2173:   PetscFunctionBegin;
2174:   *m = PETSC_MIN_REAL;
2175:   // use {} constructor syntax otherwise most vexing parse
2176:   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
2177:   PetscFunctionReturn(PETSC_SUCCESS);
2178: }

2180: // v->ops->min
2181: template <device::cupm::DeviceType T>
2182: inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept
2183: {
2184: #if CCCL_VERSION >= 3001000
2185:   using tuple_functor = detail::tuple_compare<cuda::std::less<PetscReal>>;
2186:   using unary_functor = cuda::minimum<PetscReal>;
2187: #else
2188:   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
2189:   using unary_functor = thrust::minimum<PetscReal>;
2190: #endif

2192:   PetscFunctionBegin;
2193:   *m = PETSC_MAX_REAL;
2194:   // use {} constructor syntax otherwise most vexing parse
2195:   PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
2196:   PetscFunctionReturn(PETSC_SUCCESS);
2197: }

2199: // v->ops->sum
2200: template <device::cupm::DeviceType T>
2201: inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept
2202: {
2203:   PetscFunctionBegin;
2204:   if (const auto n = v->map->n) {
2205:     PetscDeviceContext dctx;
2206:     cupmStream_t       stream;

2208:     PetscCall(GetHandles_(&dctx, &stream));
2209:     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
2210:     // REVIEW ME: why not cupmBlasXasum()?
2211:     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
2212:     // REVIEW ME: must be at least n additions
2213:     PetscCall(PetscLogGpuFlops(n));
2214:   } else {
2215:     *sum = 0.0;
2216:   }
2217:   PetscFunctionReturn(PETSC_SUCCESS);
2218: }

2220: template <device::cupm::DeviceType T>
2221: inline PetscErrorCode VecSeq_CUPM<T>::ShiftAsync(Vec v, PetscScalar shift, PetscDeviceContext dctx) noexcept
2222: {
2223:   PetscFunctionBegin;
2224:   PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v, nullptr, dctx));
2225:   PetscFunctionReturn(PETSC_SUCCESS);
2226: }

2228: template <device::cupm::DeviceType T>
2229: inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept
2230: {
2231:   PetscFunctionBegin;
2232:   PetscCall(ShiftAsync(v, shift, nullptr));
2233:   PetscFunctionReturn(PETSC_SUCCESS);
2234: }

2236: template <device::cupm::DeviceType T>
2237: inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept
2238: {
2239:   PetscFunctionBegin;
2240:   if (const auto n = v->map->n) {
2241:     PetscBool          iscurand;
2242:     PetscDeviceContext dctx;

2244:     PetscCall(GetHandles_(&dctx));
2245:     PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
2246:     if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
2247:     else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
2248:   } else {
2249:     PetscCall(MaybeIncrementEmptyLocalVec(v));
2250:   }
2251:   // REVIEW ME: flops????
2252:   // REVIEW ME: Timing???
2253:   PetscFunctionReturn(PETSC_SUCCESS);
2254: }

2256: // v->ops->setpreallocation
2257: template <device::cupm::DeviceType T>
2258: inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
2259: {
2260:   PetscDeviceContext dctx;

2262:   PetscFunctionBegin;
2263:   PetscCall(GetHandles_(&dctx));
2264:   PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
2265:   PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
2266:   PetscFunctionReturn(PETSC_SUCCESS);
2267: }

2269: // v->ops->setvaluescoo
2270: template <device::cupm::DeviceType T>
2271: inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
2272: {
2273:   auto               vv = const_cast<PetscScalar *>(v);
2274:   PetscMemType       memtype;
2275:   PetscDeviceContext dctx;
2276:   cupmStream_t       stream;

2278:   PetscFunctionBegin;
2279:   PetscCall(GetHandles_(&dctx, &stream));
2280:   PetscCall(PetscGetMemType(v, &memtype));
2281:   if (PetscMemTypeHost(memtype)) {
2282:     const auto size = VecIMPLCast(x)->coo_n;

2284:     // If user gave v[] in host, we might need to copy it to device if any
2285:     PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
2286:     PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
2287:   }

2289:   if (const auto n = x->map->n) {
2290:     const auto vcu = VecCUPMCast(x);

2292:     PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
2293:   } else {
2294:     PetscCall(MaybeIncrementEmptyLocalVec(x));
2295:   }

2297:   if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
2298:   PetscCall(PetscDeviceContextSynchronize(dctx));
2299:   PetscFunctionReturn(PETSC_SUCCESS);
2300: }

2302: } // namespace impl

2304: } // namespace cupm

2306: } // namespace vec

2308: } // namespace Petsc