Actual source code: snespatch.c

petsc-3.12.5 2020-03-29
Report Typos and Errors
  1: /*
  2:       Defines a SNES that can consist of a collection of SNESes on patches of the domain
  3: */
  4:  #include <petsc/private/snesimpl.h>
  5: #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */
  6:  #include <petscsf.h>
  7:  #include <petscsection.h>

  9: typedef struct {
 10:   PC pc; /* The linear patch preconditioner */
 11: } SNES_Patch;

 13: static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
 14: {
 15:   PC                pc      = (PC) ctx;
 16:   PC_PATCH          *pcpatch = (PC_PATCH *) pc->data;
 17:   PetscInt          pt, size, i;
 18:   const PetscInt    *indices;
 19:   const PetscScalar *X;
 20:   PetscScalar       *XWithAll;
 21:   PetscErrorCode    ierr;


 25:   /* scatter from x to patch->patchStateWithAll[pt] */
 26:   pt = pcpatch->currentPatch;
 27:   ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);

 29:   ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
 30:   VecGetArrayRead(x, &X);
 31:   VecGetArray(pcpatch->patchStateWithAll[pt], &XWithAll);

 33:   for (i = 0; i < size; ++i) {
 34:     XWithAll[indices[i]] = X[i];
 35:   }

 37:   VecRestoreArray(pcpatch->patchStateWithAll[pt], &XWithAll);
 38:   VecRestoreArrayRead(x, &X);
 39:   ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);

 41:   PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll[pt], F, pt);
 42:   return(0);
 43: }

 45: static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
 46: {
 47:   PC                pc      = (PC) ctx;
 48:   PC_PATCH          *pcpatch = (PC_PATCH *) pc->data;
 49:   PetscInt          pt, size, i;
 50:   const PetscInt    *indices;
 51:   const PetscScalar *X;
 52:   PetscScalar       *XWithAll;
 53:   PetscErrorCode    ierr;

 56:   /* scatter from x to patch->patchStateWithAll[pt] */
 57:   pt = pcpatch->currentPatch;
 58:   ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);

 60:   ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
 61:   VecGetArrayRead(x, &X);
 62:   VecGetArray(pcpatch->patchStateWithAll[pt], &XWithAll);

 64:   for (i = 0; i < size; ++i) {
 65:     XWithAll[indices[i]] = X[i];
 66:   }

 68:   VecRestoreArray(pcpatch->patchStateWithAll[pt], &XWithAll);
 69:   VecRestoreArrayRead(x, &X);
 70:   ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);

 72:   PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll[pt], M, pcpatch->currentPatch, PETSC_FALSE);
 73:   return(0);
 74: }

 76: static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
 77: {
 78:   PC_PATCH       *patch = (PC_PATCH *) pc->data;
 79:   const char     *prefix;
 80:   PetscInt       i, pStart, dof;

 84:   if (!pc->setupcalled) {
 85:     PetscMalloc1(patch->npatch, &patch->solver);
 86:     PCGetOptionsPrefix(pc, &prefix);
 87:     PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
 88:     for (i = 0; i < patch->npatch; ++i) {
 89:       SNES snes;

 91:       SNESCreate(PETSC_COMM_SELF, &snes);
 92:       SNESSetOptionsPrefix(snes, prefix);
 93:       SNESAppendOptionsPrefix(snes, "sub_");
 94:       PetscObjectIncrementTabLevel((PetscObject) snes, (PetscObject) pc, 2);
 95:       PetscLogObjectParent((PetscObject) pc, (PetscObject) snes);
 96:       patch->solver[i] = (PetscObject) snes;
 97:     }

 99:     PetscMalloc1(patch->npatch, &patch->patchResidual);
100:     PetscMalloc1(patch->npatch, &patch->patchState);
101:     PetscMalloc1(patch->npatch, &patch->patchStateWithAll);
102:     for (i = 0; i < patch->npatch; ++i) {
103:       VecDuplicate(patch->patchRHS[i], &patch->patchResidual[i]);
104:       VecDuplicate(patch->patchUpdate[i], &patch->patchState[i]);

106:       PetscSectionGetDof(patch->gtolCountsWithAll, i+pStart, &dof);
107:       VecCreateSeq(PETSC_COMM_SELF, dof, &patch->patchStateWithAll[i]);
108:       VecSetUp(patch->patchStateWithAll[i]);
109:     }
110:     VecDuplicate(patch->localUpdate, &patch->localState);
111:   }
112:   for (i = 0; i < patch->npatch; ++i) {
113:     SNES snes = (SNES) patch->solver[i];

115:     SNESSetFunction(snes, patch->patchResidual[i], SNESPatchComputeResidual_Private, pc);
116:     SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc);
117:   }
118:   if (!pc->setupcalled && patch->optionsSet) for (i = 0; i < patch->npatch; ++i) {SNESSetFromOptions((SNES) patch->solver[i]);}
119:   return(0);
120: }

122: static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
123: {
124:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
125:   PetscInt       pStart;

129:   patch->currentPatch = i;
130:   PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0);

132:   /* Scatter the overlapped global state to our patch state vector */
133:   PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
134:   PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchState[i], INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR);
135:   PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchStateWithAll[i], INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL);

137:   /* Set initial guess to be current state*/
138:   VecCopy(patch->patchState[i], patchUpdate);
139:   /* Solve for new state */
140:   SNESSolve((SNES) patch->solver[i], patchRHS, patchUpdate);
141:   /* To compute update, subtract off previous state */
142:   VecAXPY(patchUpdate, -1.0, patch->patchState[i]);

144:   PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0);
145:   return(0);
146: }

148: static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
149: {
150:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
151:   PetscInt       i;

155:   if (patch->solver) {
156:     for (i = 0; i < patch->npatch; ++i) {SNESReset((SNES) patch->solver[i]);}
157:   }

159:   if (patch->patchResidual) {
160:     for (i = 0; i < patch->npatch; ++i) {VecDestroy(&patch->patchResidual[i]);}
161:     PetscFree(patch->patchResidual);
162:   }

164:   if (patch->patchState) {
165:     for (i = 0; i < patch->npatch; ++i) {VecDestroy(&patch->patchState[i]);}
166:     PetscFree(patch->patchState);
167:   }

169:   if (patch->patchStateWithAll) {
170:     for (i = 0; i < patch->npatch; ++i) {VecDestroy(&patch->patchStateWithAll[i]);}
171:     PetscFree(patch->patchStateWithAll);
172:   }

174:   VecDestroy(&patch->localState);
175:   return(0);
176: }

178: static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
179: {
180:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
181:   PetscInt       i;

185:   if (patch->solver) {
186:     for (i = 0; i < patch->npatch; ++i) {SNESDestroy((SNES *) &patch->solver[i]);}
187:     PetscFree(patch->solver);
188:   }
189:   return(0);
190: }

192: static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
193: {
194:   PC_PATCH      *patch = (PC_PATCH *) pc->data;

198:   PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate[i], patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR);
199:   return(0);
200: }

202: static PetscErrorCode SNESSetUp_Patch(SNES snes)
203: {
204:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
205:   DM             dm;
206:   Mat            dummy;
207:   Vec            F;
208:   PetscInt       n, N;

212:   SNESGetDM(snes, &dm);
213:   PCSetDM(patch->pc, dm);
214:   SNESGetFunction(snes, &F, NULL, NULL);
215:   VecGetLocalSize(F, &n);
216:   VecGetSize(F, &N);
217:   MatCreateShell(PetscObjectComm((PetscObject) snes), n, n, N, N, (void *) snes, &dummy);
218:   PCSetOperators(patch->pc, dummy, dummy);
219:   MatDestroy(&dummy);
220:   PCSetUp(patch->pc);
221:   /* allocate workspace */
222:   return(0);
223: }

225: static PetscErrorCode SNESReset_Patch(SNES snes)
226: {
227:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

231:   PCReset(patch->pc);
232:   return(0);
233: }

235: static PetscErrorCode SNESDestroy_Patch(SNES snes)
236: {
237:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

241:   SNESReset_Patch(snes);
242:   PCDestroy(&patch->pc);
243:   PetscFree(snes->data);
244:   return(0);
245: }

247: static PetscErrorCode SNESSetFromOptions_Patch(PetscOptionItems *PetscOptionsObject, SNES snes)
248: {
249:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
250:   const char    *prefix;

254:   PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix);
255:   PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix);
256:   PCSetFromOptions(patch->pc);
257:   return(0);
258: }

260: static PetscErrorCode SNESView_Patch(SNES snes,PetscViewer viewer)
261: {
262:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
263:   PetscBool      iascii;

267:   PetscObjectTypeCompare((PetscObject) viewer, PETSCVIEWERASCII, &iascii);
268:   if (iascii) {
269:     PetscViewerASCIIPrintf(viewer,"SNESPATCH\n");
270:   }
271:   PetscViewerASCIIPushTab(viewer);
272:   PCView(patch->pc, viewer);
273:   PetscViewerASCIIPopTab(viewer);
274:   return(0);
275: }

277: static PetscErrorCode SNESSolve_Patch(SNES snes)
278: {
279:   SNES_Patch        *patch = (SNES_Patch *) snes->data;
280:   PC_PATCH          *pcpatch = (PC_PATCH *) patch->pc->data;
281:   SNESLineSearch    ls;
282:   Vec               rhs, update, state, residual;
283:   const PetscScalar *globalState  = NULL;
284:   PetscScalar       *localState   = NULL;
285:   PetscInt          its = 0;
286:   PetscReal         xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
287:   PetscErrorCode    ierr;

290:   SNESGetSolution(snes, &state);
291:   SNESGetSolutionUpdate(snes, &update);
292:   SNESGetRhs(snes, &rhs);

294:   SNESGetFunction(snes, &residual, NULL, NULL);
295:   SNESGetLineSearch(snes, &ls);

297:   SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING);
298:   VecSet(update, 0.0);
299:   SNESComputeFunction(snes, state, residual);

301:   VecNorm(state, NORM_2, &xnorm);
302:   VecNorm(residual, NORM_2, &fnorm);
303:   snes->ttol = fnorm*snes->rtol;

305:   if (snes->ops->converged) {
306:     (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
307:   } else {
308:     SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,0);
309:   }
310:   SNESLogConvergenceHistory(snes, fnorm, 0); /* should we count lits from the patches? */
311:   SNESMonitor(snes, its, fnorm);

313:   /* The main solver loop */
314:   for (its = 0; its < snes->max_its; its++) {

316:     SNESSetIterationNumber(snes, its);

318:     /* Scatter state vector to overlapped vector on all patches.
319:        The vector pcpatch->localState is scattered to each patch
320:        in PCApply_PATCH_Nonlinear. */
321:     VecGetArrayRead(state, &globalState);
322:     VecGetArray(pcpatch->localState, &localState);
323:     PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState);
324:     PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState);
325:     VecRestoreArray(pcpatch->localState, &localState);
326:     VecRestoreArrayRead(state, &globalState);

328:     /* The looping over patches happens here */
329:     PCApply(patch->pc, rhs, update);

331:     /* Apply a line search. This will often be basic with
332:        damping = 1/(max number of patches a dof can be in),
333:        but not always */
334:     VecScale(update, -1.0);
335:     SNESLineSearchApply(ls, state, residual, &fnorm, update);

337:     VecNorm(state, NORM_2, &xnorm);
338:     VecNorm(update, NORM_2, &ynorm);

340:     if (snes->ops->converged) {
341:       (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
342:     } else {
343:       SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,0);
344:     }
345:     SNESLogConvergenceHistory(snes, fnorm, 0); /* FIXME: should we count lits? */
346:     SNESMonitor(snes, its, fnorm);
347:   }

349:   if (its == snes->max_its) { SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT); }
350:   return(0);
351: }

353: /*MC
354:   SNESPATCH - Solve a nonlinear problem by composing together many nonlinear solvers on patches

356:   Level: intermediate

358: .seealso:  SNESCreate(), SNESSetType(), SNESType (for list of available types), SNES,
359:            PCPATCH

361:    References:
362: .  1. - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015

364: M*/
365: PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
366: {
368:   SNES_Patch     *patch;
369:   PC_PATCH       *patchpc;
370:   SNESLineSearch linesearch;

373:   PetscNewLog(snes, &patch);

375:   snes->ops->solve          = SNESSolve_Patch;
376:   snes->ops->setup          = SNESSetUp_Patch;
377:   snes->ops->reset          = SNESReset_Patch;
378:   snes->ops->destroy        = SNESDestroy_Patch;
379:   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
380:   snes->ops->view           = SNESView_Patch;

382:   SNESGetLineSearch(snes,&linesearch);
383:   SNESLineSearchSetType(linesearch,SNESLINESEARCHBASIC);
384:   snes->usesksp        = PETSC_FALSE;

386:   snes->alwayscomputesfinalresidual = PETSC_FALSE;

388:   snes->data = (void *) patch;
389:   PCCreate(PetscObjectComm((PetscObject) snes), &patch->pc);
390:   PCSetType(patch->pc, PCPATCH);

392:   patchpc = (PC_PATCH*) patch->pc->data;
393:   patchpc->classname = "snes";
394:   patchpc->isNonlinear = PETSC_TRUE;

396:   patchpc->setupsolver   = PCSetUp_PATCH_Nonlinear;
397:   patchpc->applysolver   = PCApply_PATCH_Nonlinear;
398:   patchpc->resetsolver   = PCReset_PATCH_Nonlinear;
399:   patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
400:   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;

402:   return(0);
403: }

405: PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap,
406:                                             const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
407: {
408:   SNES_Patch     *patch = (SNES_Patch *) snes->data;
410:   DM             dm;

413:   SNESGetDM(snes, &dm);
414:   if (!dm) SETERRQ(PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES\n");
415:   PCSetDM(patch->pc, dm);
416:   PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes);
417:   return(0);
418: }

420: PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
421: {
422:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

426:   PCPatchSetComputeOperator(patch->pc, func, ctx);
427:   return(0);
428: }

430: PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
431: {
432:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

436:   PCPatchSetComputeFunction(patch->pc, func, ctx);
437:   return(0);
438: }

440: PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
441: {
442:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

446:   PCPatchSetConstructType(patch->pc, ctype, func, ctx);
447:   return(0);
448: }

450: PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
451: {
452:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

456:   PCPatchSetCellNumbering(patch->pc, cellNumbering);
457:   return(0);
458: }