Actual source code: dmceed.c

  1: #include <petsc/private/dmimpl.h>
  2: #include <petscdmceed.h>

  4: #ifdef PETSC_HAVE_LIBCEED
  5: #include <petsc/private/dmpleximpl.h>
  6: #include <petscdmplexceed.h>
  7: #include <petscfeceed.h>

  9: /*@C
 10:   DMGetCeed - Get the LibCEED context associated with this `DM`

 12:   Not Collective

 14:   Input Parameter:
 15: . DM   - The `DM`

 17:   Output Parameter:
 18: . ceed - The LibCEED context

 20:   Level: intermediate

 22: .seealso: `DM`, `DMCreate()`
 23: @*/
 24: PetscErrorCode DMGetCeed(DM dm, Ceed *ceed)
 25: {
 26:   PetscFunctionBegin;
 28:   PetscAssertPointer(ceed, 2);
 29:   if (!dm->ceed) {
 30:     char        ceedresource[PETSC_MAX_PATH_LEN]; /* libCEED resource specifier */
 31:     const char *prefix;

 33:     PetscCall(PetscStrncpy(ceedresource, "/cpu/self", sizeof(ceedresource)));
 34:     PetscCall(PetscObjectGetOptionsPrefix((PetscObject)dm, &prefix));
 35:     PetscCall(PetscOptionsGetString(NULL, prefix, "-dm_ceed", ceedresource, sizeof(ceedresource), NULL));
 36:     PetscCallCEED(CeedInit(ceedresource, &dm->ceed));
 37:   }
 38:   *ceed = dm->ceed;
 39:   PetscFunctionReturn(PETSC_SUCCESS);
 40: }

 42: static CeedMemType PetscMemType2Ceed(PetscMemType mem_type)
 43: {
 44:   return PetscMemTypeDevice(mem_type) ? CEED_MEM_DEVICE : CEED_MEM_HOST;
 45: }

 47: PetscErrorCode VecGetCeedVector(Vec X, Ceed ceed, CeedVector *cx)
 48: {
 49:   PetscMemType memtype;
 50:   PetscScalar *x;
 51:   PetscInt     n;

 53:   PetscFunctionBegin;
 54:   PetscCall(VecGetLocalSize(X, &n));
 55:   PetscCall(VecGetArrayAndMemType(X, &x, &memtype));
 56:   PetscCallCEED(CeedVectorCreate(ceed, n, cx));
 57:   PetscCallCEED(CeedVectorSetArray(*cx, PetscMemType2Ceed(memtype), CEED_USE_POINTER, x));
 58:   PetscFunctionReturn(PETSC_SUCCESS);
 59: }

 61: PetscErrorCode VecRestoreCeedVector(Vec X, CeedVector *cx)
 62: {
 63:   PetscFunctionBegin;
 64:   PetscCall(VecRestoreArrayAndMemType(X, NULL));
 65:   PetscCallCEED(CeedVectorDestroy(cx));
 66:   PetscFunctionReturn(PETSC_SUCCESS);
 67: }

 69: PetscErrorCode VecGetCeedVectorRead(Vec X, Ceed ceed, CeedVector *cx)
 70: {
 71:   PetscMemType       memtype;
 72:   const PetscScalar *x;
 73:   PetscInt           n;
 74:   PetscFunctionBegin;
 75:   PetscCall(VecGetLocalSize(X, &n));
 76:   PetscCall(VecGetArrayReadAndMemType(X, &x, &memtype));
 77:   PetscCallCEED(CeedVectorCreate(ceed, n, cx));
 78:   PetscCallCEED(CeedVectorSetArray(*cx, PetscMemType2Ceed(memtype), CEED_USE_POINTER, (PetscScalar *)x));
 79:   PetscFunctionReturn(PETSC_SUCCESS);
 80: }

 82: PetscErrorCode VecRestoreCeedVectorRead(Vec X, CeedVector *cx)
 83: {
 84:   PetscFunctionBegin;
 85:   PetscCall(VecRestoreArrayReadAndMemType(X, NULL));
 86:   PetscCallCEED(CeedVectorDestroy(cx));
 87:   PetscFunctionReturn(PETSC_SUCCESS);
 88: }

 90: CEED_QFUNCTION(Geometry2D)(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out)
 91: {
 92:   const CeedScalar *x = in[0], *Jac = in[1], *w = in[2];
 93:   CeedScalar       *qdata = out[0];

 95:   CeedPragmaSIMD for (CeedInt i = 0; i < Q; ++i)
 96:   {
 97:     const CeedScalar J[2][2] = {
 98:       {Jac[i + Q * 0], Jac[i + Q * 2]},
 99:       {Jac[i + Q * 1], Jac[i + Q * 3]}
100:     };
101:     const CeedScalar det = J[0][0] * J[1][1] - J[0][1] * J[1][0];

103:     qdata[i + Q * 0] = det * w[i];
104:     qdata[i + Q * 1] = x[i + Q * 0];
105:     qdata[i + Q * 2] = x[i + Q * 1];
106:     qdata[i + Q * 3] = J[1][1] / det;
107:     qdata[i + Q * 4] = -J[1][0] / det;
108:     qdata[i + Q * 5] = -J[0][1] / det;
109:     qdata[i + Q * 6] = J[0][0] / det;
110:   }
111:   return CEED_ERROR_SUCCESS;
112: }

114: CEED_QFUNCTION(Geometry3D)(void *ctx, const CeedInt Q, const CeedScalar *const *in, CeedScalar *const *out)
115: {
116:   const CeedScalar *Jac = in[1], *w = in[2];
117:   CeedScalar       *qdata = out[0];

119:   CeedPragmaSIMD for (CeedInt i = 0; i < Q; ++i)
120:   {
121:     const CeedScalar J[3][3] = {
122:       {Jac[i + Q * 0], Jac[i + Q * 3], Jac[i + Q * 6]},
123:       {Jac[i + Q * 1], Jac[i + Q * 4], Jac[i + Q * 7]},
124:       {Jac[i + Q * 2], Jac[i + Q * 5], Jac[i + Q * 8]}
125:     };
126:     const CeedScalar det = J[0][0] * (J[1][1] * J[2][2] - J[1][2] * J[2][1]) + J[0][1] * (J[1][2] * J[2][0] - J[1][0] * J[2][2]) + J[0][2] * (J[1][0] * J[2][1] - J[1][1] * J[2][0]);

128:     qdata[i + Q * 0] = det * w[i]; /* det J * weight */
129:   }
130:   return CEED_ERROR_SUCCESS;
131: }

133: static PetscErrorCode DMCeedCreateGeometry(DM dm, IS cellIS, PetscInt *Nqdata, CeedElemRestriction *erq, CeedVector *qd, DMCeed *soldata)
134: {
135:   Ceed              ceed;
136:   DMCeed            sd;
137:   PetscDS           ds;
138:   PetscFE           fe;
139:   CeedQFunctionUser geom     = NULL;
140:   const char       *geomName = NULL;
141:   const PetscInt   *cells;
142:   PetscInt          dim, cdim, cStart, cEnd, Ncell, Nq;

144:   PetscFunctionBegin;
145:   PetscCall(PetscCalloc1(1, &sd));
146:   PetscCall(DMGetDimension(dm, &dim));
147:   PetscCall(DMGetCoordinateDim(dm, &cdim));
148:   PetscCall(DMGetCeed(dm, &ceed));
149:   PetscCall(ISGetPointRange(cellIS, &cStart, &cEnd, &cells));
150:   Ncell = cEnd - cStart;

152:   PetscCall(DMGetCellDS(dm, cells ? cells[cStart] : cStart, &ds, NULL));
153:   PetscCall(PetscDSGetDiscretization(ds, 0, (PetscObject *)&fe));
154:   PetscCall(PetscFEGetCeedBasis(fe, &sd->basis));
155:   PetscCall(CeedBasisGetNumQuadraturePoints(sd->basis, &Nq));
156:   PetscCall(DMPlexGetCeedRestriction(dm, NULL, 0, 0, 0, &sd->er));

158:   *Nqdata = 1 + cdim + cdim * dim;
159:   PetscCallCEED(CeedElemRestrictionCreateStrided(ceed, Ncell, Nq, *Nqdata, Ncell * Nq * (*Nqdata), CEED_STRIDES_BACKEND, erq));

161:   switch (dim) {
162:   case 2:
163:     geom     = Geometry2D;
164:     geomName = Geometry2D_loc;
165:     break;
166:   case 3:
167:     geom     = Geometry3D;
168:     geomName = Geometry3D_loc;
169:     break;
170:   }
171:   PetscCallCEED(CeedQFunctionCreateInterior(ceed, 1, geom, geomName, &sd->qf));
172:   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "x", cdim, CEED_EVAL_INTERP));
173:   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "dx", cdim * dim, CEED_EVAL_GRAD));
174:   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "weight", 1, CEED_EVAL_WEIGHT));
175:   PetscCallCEED(CeedQFunctionAddOutput(sd->qf, "qdata", *Nqdata, CEED_EVAL_NONE));

177:   PetscCallCEED(CeedOperatorCreate(ceed, sd->qf, CEED_QFUNCTION_NONE, CEED_QFUNCTION_NONE, &sd->op));
178:   PetscCallCEED(CeedOperatorSetField(sd->op, "x", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
179:   PetscCallCEED(CeedOperatorSetField(sd->op, "dx", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
180:   PetscCallCEED(CeedOperatorSetField(sd->op, "weight", CEED_ELEMRESTRICTION_NONE, sd->basis, CEED_VECTOR_NONE));
181:   PetscCallCEED(CeedOperatorSetField(sd->op, "qdata", *erq, CEED_BASIS_COLLOCATED, CEED_VECTOR_ACTIVE));

183:   PetscCallCEED(CeedElemRestrictionCreateVector(*erq, qd, NULL));
184:   *soldata = sd;
185:   PetscFunctionReturn(PETSC_SUCCESS);
186: }

188: PetscErrorCode DMRefineHook_Ceed(DM coarse, DM fine, void *ctx)
189: {
190:   PetscFunctionBegin;
191:   if (coarse->dmceed) PetscCall(DMCeedCreate(fine, coarse->dmceed->geom ? PETSC_TRUE : PETSC_FALSE, coarse->dmceed->func, coarse->dmceed->funcSource));
192:   PetscFunctionReturn(PETSC_SUCCESS);
193: }

195: PetscErrorCode DMCeedCreate_Internal(DM dm, IS cellIS, PetscBool createGeometry, CeedQFunctionUser func, const char *func_source, DMCeed *soldata)
196: {
197:   PetscDS  ds;
198:   PetscFE  fe;
199:   DMCeed   sd;
200:   Ceed     ceed;
201:   PetscInt dim, Nc, Nq, Nqdata = 0;

203:   PetscFunctionBegin;
204:   PetscCall(PetscCalloc1(1, &sd));
205:   PetscCall(DMGetDimension(dm, &dim));
206:   PetscCall(DMGetCeed(dm, &ceed));
207:   PetscCall(DMGetDS(dm, &ds));
208:   PetscCall(PetscDSGetDiscretization(ds, 0, (PetscObject *)&fe));
209:   PetscCall(PetscFEGetCeedBasis(fe, &sd->basis));
210:   PetscCall(PetscFEGetNumComponents(fe, &Nc));
211:   PetscCall(CeedBasisGetNumQuadraturePoints(sd->basis, &Nq));
212:   PetscCall(DMPlexGetCeedRestriction(dm, NULL, 0, 0, 0, &sd->er));

214:   if (createGeometry) {
215:     DM cdm;

217:     PetscCall(DMGetCoordinateDM(dm, &cdm));
218:     PetscCall(DMCeedCreateGeometry(cdm, cellIS, &Nqdata, &sd->erq, &sd->qd, &sd->geom));
219:   }

221:   if (sd->geom) {
222:     PetscInt cdim, Nqx;

224:     PetscCallCEED(CeedBasisGetNumQuadraturePoints(sd->geom->basis, &Nqx));
225:     PetscCheck(Nqx == Nq, PetscObjectComm((PetscObject)dm), PETSC_ERR_ARG_INCOMP, "Number of qpoints for solution %" PetscInt_FMT " != %" PetscInt_FMT " Number of qpoints for coordinates", Nq, Nqx);
226:     /* TODO Remove this limitation */
227:     PetscCall(DMGetCoordinateDim(dm, &cdim));
228:     PetscCheck(dim == cdim, PetscObjectComm((PetscObject)dm), PETSC_ERR_ARG_INCOMP, "Topological dimension %" PetscInt_FMT " != %" PetscInt_FMT " embedding dimension", dim, cdim);
229:   }

231:   PetscCallCEED(CeedQFunctionCreateInterior(ceed, 1, func, func_source, &sd->qf));
232:   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "u", Nc, CEED_EVAL_INTERP));
233:   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "du", Nc * dim, CEED_EVAL_GRAD));
234:   PetscCallCEED(CeedQFunctionAddInput(sd->qf, "qdata", Nqdata, CEED_EVAL_NONE));
235:   PetscCallCEED(CeedQFunctionAddOutput(sd->qf, "v", Nc, CEED_EVAL_INTERP));
236:   PetscCallCEED(CeedQFunctionAddOutput(sd->qf, "dv", Nc * dim, CEED_EVAL_GRAD));

238:   PetscCallCEED(CeedOperatorCreate(ceed, sd->qf, CEED_QFUNCTION_NONE, CEED_QFUNCTION_NONE, &sd->op));
239:   PetscCallCEED(CeedOperatorSetField(sd->op, "u", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
240:   PetscCallCEED(CeedOperatorSetField(sd->op, "du", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
241:   PetscCallCEED(CeedOperatorSetField(sd->op, "qdata", sd->erq, CEED_BASIS_COLLOCATED, sd->qd));
242:   PetscCallCEED(CeedOperatorSetField(sd->op, "v", sd->er, sd->basis, CEED_VECTOR_ACTIVE));
243:   PetscCallCEED(CeedOperatorSetField(sd->op, "dv", sd->er, sd->basis, CEED_VECTOR_ACTIVE));

245:   // Handle refinement
246:   sd->func = func;
247:   PetscCall(PetscStrallocpy(func_source, &sd->funcSource));
248:   PetscCall(DMRefineHookAdd(dm, DMRefineHook_Ceed, NULL, NULL));

250:   *soldata = sd;
251:   PetscFunctionReturn(PETSC_SUCCESS);
252: }

254: PetscErrorCode DMCeedCreate(DM dm, PetscBool createGeometry, CeedQFunctionUser func, const char *func_source)
255: {
256:   DM plex;
257:   IS cellIS;

259:   PetscFunctionBegin;
260:   PetscCall(DMConvert(dm, DMPLEX, &plex));
261:   PetscCall(DMPlexGetAllCells_Internal(plex, &cellIS));
262:   #ifdef PETSC_HAVE_LIBCEED
263:   PetscCall(DMCeedCreate_Internal(dm, cellIS, createGeometry, func, func_source, &dm->dmceed));
264:   #endif
265:   PetscCall(ISDestroy(&cellIS));
266:   PetscCall(DMDestroy(&plex));
267:   PetscFunctionReturn(PETSC_SUCCESS);
268: }

270: #endif

272: PetscErrorCode DMCeedDestroy(DMCeed *pceed)
273: {
274:   DMCeed p = *pceed;

276:   PetscFunctionBegin;
277:   if (!p) PetscFunctionReturn(PETSC_SUCCESS);
278: #ifdef PETSC_HAVE_LIBCEED
279:   PetscCall(PetscFree(p->funcSource));
280:   if (p->qd) PetscCallCEED(CeedVectorDestroy(&p->qd));
281:   if (p->op) PetscCallCEED(CeedOperatorDestroy(&p->op));
282:   if (p->qf) PetscCallCEED(CeedQFunctionDestroy(&p->qf));
283:   if (p->erq) PetscCallCEED(CeedElemRestrictionDestroy(&p->erq));
284:   PetscCall(DMCeedDestroy(&p->geom));
285: #endif
286:   PetscCall(PetscFree(p));
287:   *pceed = NULL;
288:   PetscFunctionReturn(PETSC_SUCCESS);
289: }

291: PetscErrorCode DMCeedComputeGeometry(DM dm, DMCeed sd)
292: {
293: #ifdef PETSC_HAVE_LIBCEED
294:   Ceed       ceed;
295:   Vec        coords;
296:   CeedVector ccoords;
297: #endif

299:   PetscFunctionBegin;
300: #ifdef PETSC_HAVE_LIBCEED
301:   PetscCall(DMGetCeed(dm, &ceed));
302:   PetscCall(DMGetCoordinatesLocal(dm, &coords));
303:   PetscCall(VecGetCeedVectorRead(coords, ceed, &ccoords));
304:   PetscCallCEED(CeedOperatorApply(sd->geom->op, ccoords, sd->qd, CEED_REQUEST_IMMEDIATE));
305:   PetscCall(VecRestoreCeedVectorRead(coords, &ccoords));
306: #endif
307:   PetscFunctionReturn(PETSC_SUCCESS);
308: }