// vecmpicupm_impl.hpp

  1: #pragma once

  3: #include "vecmpicupm.hpp"

  5: #include <../src/sys/objects/device/impls/cupm/kernels.hpp>

  7: #include <petsc/private/sfimpl.h>

  9: namespace Petsc
 10: {

 12: namespace vec
 13: {

 15: namespace cupm
 16: {

 18: namespace impl
 19: {

// Retrieve the host-side MPI implementation struct stored in v->data.
// The cast is unchecked: v must actually be a VECMPI-derived vector.
template <device::cupm::DeviceType T>
inline Vec_MPI *VecMPI_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_MPI *>(v->data);
}

// Type name of the device-aware MPI vector implemented by this class
// (e.g. VECMPICUDA/VECMPIHIP depending on T)
template <device::cupm::DeviceType T>
inline constexpr VecType VecMPI_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECMPICUPM();
}

// Type name of the plain host-side counterpart this class wraps
template <device::cupm::DeviceType T>
inline constexpr VecType VecMPI_CUPM<T>::VECIMPL_() noexcept
{
  return VECMPI;
}

// Destroy hook for the underlying host implementation: tear down the async
// function table installed in VecCreate_IMPL_Private_(), then dispose of the
// Vec_MPI data via the stock MPI destroy routine
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  PetscFunctionBegin;
  // undoes VecSeq_T::InitializeAsyncFunctions()
  PetscCall(VecSeq_T::ClearAsyncFunctions(v));
  PetscCall(VecDestroy_MPI(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Thin forward to the host MPI implementation's reset-array routine; the base
// class handles the device-side bookkeeping separately
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_MPI(v);
}

// Thin forward to the host MPI implementation's place-array routine; the base
// class handles the device-side bookkeeping separately
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_MPI(v, a);
}

// Create the host-side Vec_MPI data structures without allocating the host
// array. Reports *alloc_missing = PETSC_TRUE so the CUPM base class knows it
// must allocate storage itself. The trailing host-array parameter is
// deliberately unnamed and ignored (see comment below).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt nghost, PetscScalar *) noexcept
{
  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_TRUE;
  // note host_array is always ignored, we never create it as part of the construction sequence
  // for VecMPI since we always want to either allocate it ourselves with pinned memory or set
  // it in Initialize_CUPMBase()
  PetscCall(VecCreate_MPI_Private(v, PETSC_FALSE, nghost, nullptr));
  // install the asynchronous (device-context-aware) operation table
  PetscCall(VecSeq_T::InitializeAsyncFunctions(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Internal two-step construction: build the host Vec_MPI skeleton (no host
// array), then let the CUPM base class attach/allocate the host and device
// arrays. host_array/device_array may be user-provided or nullptr; when
// allocate_missing is true the base class allocates whichever is missing.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPM_(Vec v, PetscDeviceContext dctx, PetscBool allocate_missing, PetscInt nghost, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, nghost));
  PetscCall(Initialize_CUPMBase(v, allocate_missing, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

 82: // ================================================================================== //
 83: //                                                                                    //
 84: //                                  public methods                                    //
 85: //                                                                                    //
 86: // ================================================================================== //

 88: // ================================================================================== //
 89: //                             constructors/destructors                               //

 91: // VecCreateMPICUPM()
 92: template <device::cupm::DeviceType T>
 93: inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPM(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, Vec *v, PetscBool call_set_type) noexcept
 94: {
 95:   PetscFunctionBegin;
 96:   PetscCall(Create_CUPMBase(comm, bs, n, N, v, call_set_type));
 97:   PetscFunctionReturn(PETSC_SUCCESS);
 98: }

100: // VecCreateMPICUPMWithArray[s]()
101: template <device::cupm::DeviceType T>
102: inline PetscErrorCode VecMPI_CUPM<T>::CreateMPICUPMWithArrays(MPI_Comm comm, PetscInt bs, PetscInt n, PetscInt N, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
103: {
104:   PetscDeviceContext dctx;

106:   PetscFunctionBegin;
107:   PetscCall(GetHandles_(&dctx));
108:   // do NOT call VecSetType(), otherwise ops->create() -> create() ->
109:   // CreateMPICUPM_() is called!
110:   PetscCall(CreateMPICUPM(comm, bs, n, N, v, PETSC_FALSE));
111:   PetscCall(CreateMPICUPM_(*v, dctx, PETSC_FALSE, 0, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
112:   PetscFunctionReturn(PETSC_SUCCESS);
113: }

115: // v->ops->duplicate
116: template <device::cupm::DeviceType T>
117: inline PetscErrorCode VecMPI_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
118: {
119:   const auto         vimpl  = VecIMPLCast(v);
120:   const auto         nghost = vimpl->nghost;
121:   PetscDeviceContext dctx;

123:   PetscFunctionBegin;
124:   PetscCall(GetHandles_(&dctx));
125:   // does not call VecSetType(), we set up the data structures ourselves
126:   PetscCall(Duplicate_CUPMBase(v, y, dctx, [=](Vec z) { return CreateMPICUPM_(z, dctx, PETSC_FALSE, nghost); }));

128:   /* save local representation of the parallel vector (and scatter) if it exists */
129:   if (const auto locrep = vimpl->localrep) {
130:     const auto   yimpl   = VecIMPLCast(*y);
131:     auto        &ylocrep = yimpl->localrep;
132:     PetscScalar *array;

134:     PetscCall(VecGetArray(*y, &array));
135:     PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, v->map->bs, v->map->n + nghost, array, &ylocrep));
136:     PetscCall(VecRestoreArray(*y, &array));
137:     ylocrep->ops[0] = locrep->ops[0];
138:     if (const auto scatter = (yimpl->localupdate = vimpl->localupdate)) PetscCall(PetscObjectReference(PetscObjectCast(scatter)));

140:     yimpl->ghost = vimpl->ghost;
141:     PetscCall(PetscObjectReference((PetscObject)yimpl->ghost));
142:   }
143:   PetscFunctionReturn(PETSC_SUCCESS);
144: }

146: // v->ops->bintocpu
147: template <device::cupm::DeviceType T>
148: inline PetscErrorCode VecMPI_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
149: {
150:   PetscDeviceContext dctx;

152:   PetscFunctionBegin;
153:   PetscCall(GetHandles_(&dctx));
154:   PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

156:   VecSetOp_CUPM(dot, VecDot_MPI, Dot);
157:   VecSetOp_CUPM(mdot, VecMDot_MPI, MDot);
158:   VecSetOp_CUPM(norm, VecNorm_MPI, Norm);
159:   VecSetOp_CUPM(tdot, VecTDot_MPI, TDot);
160:   VecSetOp_CUPM(resetarray, VecResetArray_MPI, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
161:   VecSetOp_CUPM(placearray, VecPlaceArray_MPI, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
162:   VecSetOp_CUPM(max, VecMax_MPI, Max);
163:   VecSetOp_CUPM(min, VecMin_MPI, Min);
164:   PetscFunctionReturn(PETSC_SUCCESS);
165: }

167: // ================================================================================== //
168: //                                   compute methods                                  //

// v->ops->norm: MPI reduction of the sequential device norm kernel
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Norm(Vec v, NormType type, PetscReal *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecNorm_MPI_Default(v, type, z, VecSeq_T::Norm));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Weighted error norms (used by TS adaptivity): local device kernel combined
// across ranks by the MPI default driver. All pointer outputs are required.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecErrorWeightedNorms_MPI_Default(U, Y, E, wnormtype, atol, vatol, rtol, vrtol, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc, VecSeq_T::ErrorWnorm));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->dot: MPI reduction of the sequential device dot-product kernel
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Dot(Vec x, Vec y, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecXDot_MPI_Default(x, y, z, VecSeq_T::Dot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->tdot: MPI reduction of the sequential device transpose-dot
// (indefinite, no conjugation) kernel
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::TDot(Vec x, Vec y, PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecXDot_MPI_Default(x, y, z, VecSeq_T::TDot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->mdot: multiple dot products x . y[0..nv) reduced across ranks
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::MDot(Vec x, PetscInt nv, const Vec y[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecMXDot_MPI_Default(x, nv, y, z, VecSeq_T::MDot));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->dotnorm2: fused dot(x,y) and ||y||^2 in one MPI reduction
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::DotNorm2(Vec x, Vec y, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscFunctionBegin;
  PetscCall(VecDotNorm2_MPI_Default(x, y, dp, nm, VecSeq_T::DotNorm2));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->max: global maximum (and, when idx != nullptr, its global index).
// ops[0] (MAXLOC) is used when the index is requested, ops[1] (MAX) otherwise.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Max(Vec x, PetscInt *idx, PetscReal *z) noexcept
{
  const MPI_Op ops[] = {MPIU_MAXLOC, MPIU_MAX};

  PetscFunctionBegin;
  PetscCall(VecMinMax_MPI_Default(x, idx, z, VecSeq_T::Max, ops));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->min: global minimum, mirror image of Max() above
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::Min(Vec x, PetscInt *idx, PetscReal *z) noexcept
{
  const MPI_Op ops[] = {MPIU_MINLOC, MPIU_MIN};

  PetscFunctionBegin;
  PetscCall(VecMinMax_MPI_Default(x, idx, z, VecSeq_T::Min, ops));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->setpreallocationcoo
// Set up COO (coordinate-format) assembly: the host implementation computes
// all index/permutation tables, then the base class mirrors them onto the
// device and allocates the device-side MPI send/receive staging buffers.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::SetPreallocationCOO(Vec x, PetscCount ncoo, const PetscInt coo_i[]) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // host implementation builds imap2/jmap2/perm2/Cperm and the COO SF
  PetscCall(VecSetPreallocationCOO_MPI(x, ncoo, coo_i));
  // both of these must exist for this to work
  PetscCall(VecCUPMAllocateCheck_(x));
  {
    const auto vcu  = VecCUPMCast(x);
    const auto vmpi = VecIMPLCast(x);

    // first array: host tables copied to device (make_coo_pair(dest_d, src, len));
    // second array: device-only buffers to allocate (sendbuf/recvbuf staging)
    // clang-format off
    PetscCall(
      SetPreallocationCOO_CUPMBase(
        x, ncoo, coo_i, dctx,
        util::make_array(
          make_coo_pair(vcu->imap2_d, vmpi->imap2, vmpi->nnz2),
          make_coo_pair(vcu->jmap2_d, vmpi->jmap2, vmpi->nnz2 + 1),
          make_coo_pair(vcu->perm2_d, vmpi->perm2, vmpi->recvlen),
          make_coo_pair(vcu->Cperm_d, vmpi->Cperm, vmpi->sendlen)
        ),
        util::make_array(
          make_coo_pair(vcu->sendbuf_d, vmpi->sendbuf, vmpi->sendlen),
          make_coo_pair(vcu->recvbuf_d, vmpi->recvbuf, vmpi->recvlen)
        )
      )
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

273: namespace kernels
274: {

276: namespace
277: {

// Device kernel: gather the nnz off-process COO values into the contiguous
// send buffer, buf[i] = vv[perm[i]], using a grid-stride loop over [0, nnz)
PETSC_KERNEL_DECL void pack_coo_values(const PetscScalar *PETSC_RESTRICT vv, PetscCount nnz, const PetscCount *PETSC_RESTRICT perm, PetscScalar *PETSC_RESTRICT buf)
{
  Petsc::device::cupm::kernels::util::grid_stride_1D(nnz, [=](PetscCount i) { buf[i] = vv[perm[i]]; });
  return;
}

// Device kernel: accumulate values received from remote ranks into the vector
// array xv. imap2 maps each of the nnz2 entries to its local location; jmap2
// and perm2 group/permute the received values, always with ADD_VALUES.
PETSC_KERNEL_DECL void add_remote_coo_values(const PetscScalar *PETSC_RESTRICT vv, PetscCount nnz2, const PetscCount *PETSC_RESTRICT imap2, const PetscCount *PETSC_RESTRICT jmap2, const PetscCount *PETSC_RESTRICT perm2, PetscScalar *PETSC_RESTRICT xv)
{
  add_coo_values_impl(vv, nnz2, jmap2, perm2, ADD_VALUES, xv, [=](PetscCount i) { return imap2[i]; });
  return;
}

291: } // namespace

#if PetscDefined(USING_HCC)
namespace do_not_use
{

// Needed to silence clang warning:
//
// warning: function 'FUNCTION NAME' is not needed and will not be emitted
//
// The warning is silly, since the function *is* used, however the host compiler does not
// appear to see this. Likely because the function using it is in a template.
//
// This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
inline void silence_warning_function_pack_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)pack_coo_values;
}

inline void silence_warning_function_add_remote_coo_values_is_not_needed_and_will_not_be_emitted()
{
  (void)add_remote_coo_values;
}

} // namespace do_not_use
#endif

318: } // namespace kernels

// v->ops->setvaluescoo
// Assemble the ncoo values v[] previously preallocated with
// SetPreallocationCOO(). v may live in host or device memory; host input is
// staged through a temporary device buffer. Off-process contributions are
// routed through the COO PetscSF, local ones added directly on the stream.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecMPI_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  PetscDeviceContext dctx;
  PetscMemType       v_memtype;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &v_memtype));
  {
    const auto vmpi      = VecIMPLCast(x);
    const auto vcu       = VecCUPMCast(x);
    const auto sf        = vmpi->coo_sf;
    const auto sendbuf_d = vcu->sendbuf_d;
    const auto recvbuf_d = vcu->recvbuf_d;
    // INSERT_VALUES overwrites so a write-only device view suffices;
    // ADD_VALUES must read the existing entries as well
    const auto xv        = imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data();
    auto       vv        = const_cast<PetscScalar *>(v);

    if (PetscMemTypeHost(v_memtype)) {
      const auto size = vmpi->coo_n;

      /* If user gave v[] in host, we might need to copy it to device if any */
      PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
      PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
    }

    /* Pack entries to be sent to remote */
    if (const auto sendlen = vmpi->sendlen) {
      PetscCall(PetscCUPMLaunchKernel1D(sendlen, 0, stream, kernels::pack_coo_values, vv, sendlen, vcu->Cperm_d, sendbuf_d));
      // need to sync up here since we are about to send this to petscsf
      // REVIEW ME: no we dont, sf just needs to learn to use PetscDeviceContext
      PetscCallCUPM(cupmStreamSynchronize(stream));
    }

    // MPI_REPLACE: each remote slot receives exactly one contribution, so no
    // reduction op is needed during the SF exchange itself
    PetscCall(PetscSFReduceWithMemTypeBegin(sf, MPIU_SCALAR, PETSC_MEMTYPE_CUPM(), sendbuf_d, PETSC_MEMTYPE_CUPM(), recvbuf_d, MPI_REPLACE));

    // add the locally-owned contributions while the SF communication proceeds
    if (const auto n = x->map->n) PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, xv));

    PetscCall(PetscSFReduceEnd(sf, MPIU_SCALAR, sendbuf_d, recvbuf_d, MPI_REPLACE));

    /* Add received remote entries */
    if (const auto nnz2 = vmpi->nnz2) PetscCall(PetscCUPMLaunchKernel1D(nnz2, 0, stream, kernels::add_remote_coo_values, recvbuf_d, nnz2, vcu->imap2_d, vcu->jmap2_d, vcu->perm2_d, xv));

    // free the staging copy before the final sync; the free is stream-ordered
    if (PetscMemTypeHost(v_memtype)) PetscCall(PetscDeviceFree(dctx, vv));
    PetscCall(PetscDeviceContextSynchronize(dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

370: } // namespace impl

372: } // namespace cupm

374: } // namespace vec

376: } // namespace Petsc