Actual source code: snespatch.c

petsc-3.13.6 2020-09-29
Report Typos and Errors
  1: /*
  2:       Defines a SNES that can consist of a collection of SNESes on patches of the domain
  3: */
  4:  #include <petsc/private/vecimpl.h>
  5:  #include <petsc/private/snesimpl.h>
  6:  #include <petsc/private/pcpatchimpl.h>
  7:  #include <petscsf.h>
  8:  #include <petscsection.h>

 10: typedef struct {
 11:   PC pc; /* The linear patch preconditioner */
 12: } SNES_Patch;

 14: static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
 15: {
 16:   PC                pc      = (PC) ctx;
 17:   PC_PATCH          *pcpatch = (PC_PATCH *) pc->data;
 18:   PetscInt          pt, size, i;
 19:   const PetscInt    *indices;
 20:   const PetscScalar *X;
 21:   PetscScalar       *XWithAll;
 22:   PetscErrorCode    ierr;


 26:   /* scatter from x to patch->patchStateWithAll[pt] */
 27:   pt = pcpatch->currentPatch;
 28:   ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);

 30:   ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
 31:   VecGetArrayRead(x, &X);
 32:   VecGetArray(pcpatch->patchStateWithAll, &XWithAll);

 34:   for (i = 0; i < size; ++i) {
 35:     XWithAll[indices[i]] = X[i];
 36:   }

 38:   VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll);
 39:   VecRestoreArrayRead(x, &X);
 40:   ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);

 42:   PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt);
 43:   return(0);
 44: }

 46: static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
 47: {
 48:   PC                pc      = (PC) ctx;
 49:   PC_PATCH          *pcpatch = (PC_PATCH *) pc->data;
 50:   PetscInt          pt, size, i;
 51:   const PetscInt    *indices;
 52:   const PetscScalar *X;
 53:   PetscScalar       *XWithAll;
 54:   PetscErrorCode    ierr;

 57:   /* scatter from x to patch->patchStateWithAll[pt] */
 58:   pt = pcpatch->currentPatch;
 59:   ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);

 61:   ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
 62:   VecGetArrayRead(x, &X);
 63:   VecGetArray(pcpatch->patchStateWithAll, &XWithAll);

 65:   for (i = 0; i < size; ++i) {
 66:     XWithAll[indices[i]] = X[i];
 67:   }

 69:   VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll);
 70:   VecRestoreArrayRead(x, &X);
 71:   ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);

 73:   PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE);
 74:   return(0);
 75: }

 77: static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
 78: {
 79:   PC_PATCH       *patch = (PC_PATCH *) pc->data;
 80:   const char     *prefix;
 81:   PetscInt       i, pStart, dof, maxDof = -1;

 85:   if (!pc->setupcalled) {
 86:     PetscMalloc1(patch->npatch, &patch->solver);
 87:     PCGetOptionsPrefix(pc, &prefix);
 88:     PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
 89:     for (i = 0; i < patch->npatch; ++i) {
 90:       SNES snes;

 92:       SNESCreate(PETSC_COMM_SELF, &snes);
 93:       SNESSetOptionsPrefix(snes, prefix);
 94:       SNESAppendOptionsPrefix(snes, "sub_");
 95:       PetscObjectIncrementTabLevel((PetscObject) snes, (PetscObject) pc, 2);
 96:       PetscLogObjectParent((PetscObject) pc, (PetscObject) snes);
 97:       patch->solver[i] = (PetscObject) snes;

 99:       PetscSectionGetDof(patch->gtolCountsWithAll, i+pStart, &dof);
100:       maxDof = PetscMax(maxDof, dof);
101:     }
102:     VecDuplicate(patch->localUpdate, &patch->localState);
103:     VecDuplicate(patch->patchRHS, &patch->patchResidual);
104:     VecDuplicate(patch->patchUpdate, &patch->patchState);

106:     VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll);
107:     VecSetUp(patch->patchStateWithAll);
108:   }
109:   for (i = 0; i < patch->npatch; ++i) {
110:     SNES snes = (SNES) patch->solver[i];

112:     SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc);
113:     SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc);
114:   }
115:   if (!pc->setupcalled && patch->optionsSet) for (i = 0; i < patch->npatch; ++i) {SNESSetFromOptions((SNES) patch->solver[i]);}
116:   return(0);
117: }

119: static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
120: {
121:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
122:   PetscInt       pStart, n;

126:   patch->currentPatch = i;
127:   PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0);

129:   /* Scatter the overlapped global state to our patch state vector */
130:   PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
131:   PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR);
132:   PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL);

134:   MatGetLocalSize(patch->mat[i], NULL, &n);
135:   patch->patchState->map->n = n;
136:   patch->patchState->map->N = n;
137:   patchUpdate->map->n = n;
138:   patchUpdate->map->N = n;
139:   patchRHS->map->n = n;
140:   patchRHS->map->N = n;
141:   /* Set initial guess to be current state*/
142:   VecCopy(patch->patchState, patchUpdate);
143:   /* Solve for new state */
144:   SNESSolve((SNES) patch->solver[i], patchRHS, patchUpdate);
145:   /* To compute update, subtract off previous state */
146:   VecAXPY(patchUpdate, -1.0, patch->patchState);

148:   PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0);
149:   return(0);
150: }

152: static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
153: {
154:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
155:   PetscInt       i;

159:   if (patch->solver) {
160:     for (i = 0; i < patch->npatch; ++i) {SNESReset((SNES) patch->solver[i]);}
161:   }

163:   VecDestroy(&patch->patchResidual);
164:   VecDestroy(&patch->patchState);
165:   VecDestroy(&patch->patchStateWithAll);

167:   VecDestroy(&patch->localState);
168:   return(0);
169: }

171: static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
172: {
173:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
174:   PetscInt       i;

178:   if (patch->solver) {
179:     for (i = 0; i < patch->npatch; ++i) {SNESDestroy((SNES *) &patch->solver[i]);}
180:     PetscFree(patch->solver);
181:   }
182:   return(0);
183: }

185: static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
186: {
187:   PC_PATCH      *patch = (PC_PATCH *) pc->data;

191:   PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR);
192:   return(0);
193: }

195: static PetscErrorCode SNESSetUp_Patch(SNES snes)
196: {
197:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
198:   DM             dm;
199:   Mat            dummy;
200:   Vec            F;
201:   PetscInt       n, N;

205:   SNESGetDM(snes, &dm);
206:   PCSetDM(patch->pc, dm);
207:   SNESGetFunction(snes, &F, NULL, NULL);
208:   VecGetLocalSize(F, &n);
209:   VecGetSize(F, &N);
210:   MatCreateShell(PetscObjectComm((PetscObject) snes), n, n, N, N, (void *) snes, &dummy);
211:   PCSetOperators(patch->pc, dummy, dummy);
212:   MatDestroy(&dummy);
213:   PCSetUp(patch->pc);
214:   /* allocate workspace */
215:   return(0);
216: }

218: static PetscErrorCode SNESReset_Patch(SNES snes)
219: {
220:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

224:   PCReset(patch->pc);
225:   return(0);
226: }

228: static PetscErrorCode SNESDestroy_Patch(SNES snes)
229: {
230:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

234:   SNESReset_Patch(snes);
235:   PCDestroy(&patch->pc);
236:   PetscFree(snes->data);
237:   return(0);
238: }

240: static PetscErrorCode SNESSetFromOptions_Patch(PetscOptionItems *PetscOptionsObject, SNES snes)
241: {
242:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
243:   const char    *prefix;

247:   PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix);
248:   PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix);
249:   PCSetFromOptions(patch->pc);
250:   return(0);
251: }

253: static PetscErrorCode SNESView_Patch(SNES snes,PetscViewer viewer)
254: {
255:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
256:   PetscBool      iascii;

260:   PetscObjectTypeCompare((PetscObject) viewer, PETSCVIEWERASCII, &iascii);
261:   if (iascii) {
262:     PetscViewerASCIIPrintf(viewer,"SNESPATCH\n");
263:   }
264:   PetscViewerASCIIPushTab(viewer);
265:   PCView(patch->pc, viewer);
266:   PetscViewerASCIIPopTab(viewer);
267:   return(0);
268: }

270: static PetscErrorCode SNESSolve_Patch(SNES snes)
271: {
272:   SNES_Patch        *patch = (SNES_Patch *) snes->data;
273:   PC_PATCH          *pcpatch = (PC_PATCH *) patch->pc->data;
274:   SNESLineSearch    ls;
275:   Vec               rhs, update, state, residual;
276:   const PetscScalar *globalState  = NULL;
277:   PetscScalar       *localState   = NULL;
278:   PetscInt          its = 0;
279:   PetscReal         xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
280:   PetscErrorCode    ierr;

283:   SNESGetSolution(snes, &state);
284:   SNESGetSolutionUpdate(snes, &update);
285:   SNESGetRhs(snes, &rhs);

287:   SNESGetFunction(snes, &residual, NULL, NULL);
288:   SNESGetLineSearch(snes, &ls);

290:   SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING);
291:   VecSet(update, 0.0);
292:   SNESComputeFunction(snes, state, residual);

294:   VecNorm(state, NORM_2, &xnorm);
295:   VecNorm(residual, NORM_2, &fnorm);
296:   snes->ttol = fnorm*snes->rtol;

298:   if (snes->ops->converged) {
299:     (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
300:   } else {
301:     SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,0);
302:   }
303:   SNESLogConvergenceHistory(snes, fnorm, 0); /* should we count lits from the patches? */
304:   SNESMonitor(snes, its, fnorm);

306:   /* The main solver loop */
307:   for (its = 0; its < snes->max_its; its++) {

309:     SNESSetIterationNumber(snes, its);

311:     /* Scatter state vector to overlapped vector on all patches.
312:        The vector pcpatch->localState is scattered to each patch
313:        in PCApply_PATCH_Nonlinear. */
314:     VecGetArrayRead(state, &globalState);
315:     VecGetArray(pcpatch->localState, &localState);
316:     PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState);
317:     PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState);
318:     VecRestoreArray(pcpatch->localState, &localState);
319:     VecRestoreArrayRead(state, &globalState);

321:     /* The looping over patches happens here */
322:     PCApply(patch->pc, rhs, update);

324:     /* Apply a line search. This will often be basic with
325:        damping = 1/(max number of patches a dof can be in),
326:        but not always */
327:     VecScale(update, -1.0);
328:     SNESLineSearchApply(ls, state, residual, &fnorm, update);

330:     VecNorm(state, NORM_2, &xnorm);
331:     VecNorm(update, NORM_2, &ynorm);

333:     if (snes->ops->converged) {
334:       (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
335:     } else {
336:       SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,0);
337:     }
338:     SNESLogConvergenceHistory(snes, fnorm, 0); /* FIXME: should we count lits? */
339:     SNESMonitor(snes, its, fnorm);
340:   }

342:   if (its == snes->max_its) { SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT); }
343:   return(0);
344: }

346: /*MC
347:   SNESPATCH - Solve a nonlinear problem by composing together many nonlinear solvers on patches

349:   Level: intermediate

351: .seealso:  SNESCreate(), SNESSetType(), SNESType (for list of available types), SNES,
352:            PCPATCH

354:    References:
355: .  1. - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015

357: M*/
358: PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
359: {
361:   SNES_Patch     *patch;
362:   PC_PATCH       *patchpc;
363:   SNESLineSearch linesearch;

366:   PetscNewLog(snes, &patch);

368:   snes->ops->solve          = SNESSolve_Patch;
369:   snes->ops->setup          = SNESSetUp_Patch;
370:   snes->ops->reset          = SNESReset_Patch;
371:   snes->ops->destroy        = SNESDestroy_Patch;
372:   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
373:   snes->ops->view           = SNESView_Patch;

375:   SNESGetLineSearch(snes,&linesearch);
376:   if (!((PetscObject)linesearch)->type_name) {
377:     SNESLineSearchSetType(linesearch,SNESLINESEARCHBASIC);
378:   }
379:   snes->usesksp        = PETSC_FALSE;

381:   snes->alwayscomputesfinalresidual = PETSC_FALSE;

383:   snes->data = (void *) patch;
384:   PCCreate(PetscObjectComm((PetscObject) snes), &patch->pc);
385:   PCSetType(patch->pc, PCPATCH);

387:   patchpc = (PC_PATCH*) patch->pc->data;
388:   patchpc->classname = "snes";
389:   patchpc->isNonlinear = PETSC_TRUE;

391:   patchpc->setupsolver   = PCSetUp_PATCH_Nonlinear;
392:   patchpc->applysolver   = PCApply_PATCH_Nonlinear;
393:   patchpc->resetsolver   = PCReset_PATCH_Nonlinear;
394:   patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
395:   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;

397:   return(0);
398: }

400: PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap,
401:                                             const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
402: {
403:   SNES_Patch     *patch = (SNES_Patch *) snes->data;
405:   DM             dm;

408:   SNESGetDM(snes, &dm);
409:   if (!dm) SETERRQ(PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES\n");
410:   PCSetDM(patch->pc, dm);
411:   PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes);
412:   return(0);
413: }

415: PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
416: {
417:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

421:   PCPatchSetComputeOperator(patch->pc, func, ctx);
422:   return(0);
423: }

425: PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
426: {
427:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

431:   PCPatchSetComputeFunction(patch->pc, func, ctx);
432:   return(0);
433: }

435: PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
436: {
437:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

441:   PCPatchSetConstructType(patch->pc, ctype, func, ctx);
442:   return(0);
443: }

445: PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
446: {
447:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

451:   PCPatchSetCellNumbering(patch->pc, cellNumbering);
452:   return(0);
453: }