Actual source code: snespatch.c

petsc-3.11.4 2019-09-28
Report Typos and Errors
  1: /*
  2:       Defines a SNES that can consist of a collection of SNESes on patches of the domain
  3: */
  4:  #include <petsc/private/snesimpl.h>
  5: #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */
  6:  #include <petscsf.h>

  8: typedef struct {
  9:   PC pc; /* The linear patch preconditioner */
 10: } SNES_Patch;

 12: static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
 13: {
 14:   PC             pc      = (PC) ctx;
 15:   PC_PATCH      *pcpatch = (PC_PATCH *) pc->data;
 16:   PetscInt       pt, size, i;
 17:   const PetscInt *indices;
 18:   const PetscScalar *X;
 19:   PetscScalar   *XWithAll;


 24:   /* scatter from x to patch->patchStateWithAll[pt] */
 25:   pt = pcpatch->currentPatch;
 26:   ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);

 28:   ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
 29:   VecGetArrayRead(x, &X);
 30:   VecGetArray(pcpatch->patchStateWithAll[pt], &XWithAll);

 32:   for (i = 0; i < size; ++i) {
 33:     XWithAll[indices[i]] = X[i];
 34:   }

 36:   VecRestoreArray(pcpatch->patchStateWithAll[pt], &XWithAll);
 37:   VecRestoreArrayRead(x, &X);
 38:   ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);

 40:   PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll[pt], F, pt);
 41:   return(0);
 42: }

 44: static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
 45: {
 46:   PC             pc      = (PC) ctx;
 47:   PC_PATCH      *pcpatch = (PC_PATCH *) pc->data;
 48:   PetscInt       pt, size, i;
 49:   const PetscInt *indices;
 50:   const PetscScalar *X;
 51:   PetscScalar   *XWithAll;

 55:   /* scatter from x to patch->patchStateWithAll[pt] */
 56:   pt = pcpatch->currentPatch;
 57:   ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);

 59:   ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
 60:   VecGetArrayRead(x, &X);
 61:   VecGetArray(pcpatch->patchStateWithAll[pt], &XWithAll);

 63:   for (i = 0; i < size; ++i) {
 64:     XWithAll[indices[i]] = X[i];
 65:   }

 67:   VecRestoreArray(pcpatch->patchStateWithAll[pt], &XWithAll);
 68:   VecRestoreArrayRead(x, &X);
 69:   ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);

 71:   PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll[pt], M, pcpatch->currentPatch, PETSC_FALSE);
 72:   return(0);
 73: }

 75: static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
 76: {
 77:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
 78:   const char    *prefix;
 79:   PetscInt       i, pStart, dof;

 83:   if (!pc->setupcalled) {
 84:     PetscMalloc1(patch->npatch, &patch->solver);
 85:     PCGetOptionsPrefix(pc, &prefix);
 86:     PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
 87:     for (i = 0; i < patch->npatch; ++i) {
 88:       SNES snes;
 89:       KSP  subksp;

 91:       SNESCreate(PETSC_COMM_SELF, &snes);
 92:       SNESSetOptionsPrefix(snes, prefix);
 93:       SNESAppendOptionsPrefix(snes, "sub_");
 94:       PetscObjectIncrementTabLevel((PetscObject) snes, (PetscObject) pc, 2);
 95:       SNESGetKSP(snes, &subksp);
 96:       PetscObjectIncrementTabLevel((PetscObject) subksp, (PetscObject) pc, 2);
 97:       PetscLogObjectParent((PetscObject) pc, (PetscObject) snes);
 98:       patch->solver[i] = (PetscObject) snes;
 99:     }

101:     PetscMalloc1(patch->npatch, &patch->patchResidual);
102:     PetscMalloc1(patch->npatch, &patch->patchState);
103:     PetscMalloc1(patch->npatch, &patch->patchStateWithAll);
104:     for (i = 0; i < patch->npatch; ++i) {
105:       VecDuplicate(patch->patchRHS[i], &patch->patchResidual[i]);
106:       VecDuplicate(patch->patchUpdate[i], &patch->patchState[i]);

108:       PetscSectionGetDof(patch->gtolCountsWithAll, i+pStart, &dof);
109:       VecCreateSeq(PETSC_COMM_SELF, dof, &patch->patchStateWithAll[i]);
110:       VecSetUp(patch->patchStateWithAll[i]);
111:     }
112:     VecDuplicate(patch->localUpdate, &patch->localState);
113:   }
114:   for (i = 0; i < patch->npatch; ++i) {
115:     SNES snes = (SNES) patch->solver[i];

117:     SNESSetFunction(snes, patch->patchResidual[i], SNESPatchComputeResidual_Private, pc);
118:     SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc);
119:   }
120:   if (!pc->setupcalled && patch->optionsSet) for (i = 0; i < patch->npatch; ++i) {SNESSetFromOptions((SNES) patch->solver[i]);}
121:   return(0);
122: }

124: static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
125: {
126:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
127:   PetscInt       pStart;

131:   patch->currentPatch = i;
132:   PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0);

134:   /* Scatter the overlapped global state to our patch state vector */
135:   PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
136:   PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchState[i], INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR);
137:   PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchStateWithAll[i], INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL);

139:   /* Set initial guess to be current state*/
140:   VecCopy(patch->patchState[i], patchUpdate);
141:   /* Solve for new state */
142:   SNESSolve((SNES) patch->solver[i], patchRHS, patchUpdate);
143:   /* To compute update, subtract off previous state */
144:   VecAXPY(patchUpdate, -1.0, patch->patchState[i]);

146:   PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0);
147:   return(0);
148: }

150: static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
151: {
152:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
153:   PetscInt       i;

157:   if (patch->solver) {
158:     for (i = 0; i < patch->npatch; ++i) {SNESReset((SNES) patch->solver[i]);}
159:   }

161:   if (patch->patchResidual) {
162:     for (i = 0; i < patch->npatch; ++i) {VecDestroy(&patch->patchResidual[i]);}
163:     PetscFree(patch->patchResidual);
164:   }

166:   if (patch->patchState) {
167:     for (i = 0; i < patch->npatch; ++i) {VecDestroy(&patch->patchState[i]);}
168:     PetscFree(patch->patchState);
169:   }

171:   if (patch->patchStateWithAll) {
172:     for (i = 0; i < patch->npatch; ++i) {VecDestroy(&patch->patchStateWithAll[i]);}
173:     PetscFree(patch->patchStateWithAll);
174:   }

176:   VecDestroy(&patch->localState);
177:   return(0);
178: }

180: static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
181: {
182:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
183:   PetscInt       i;

187:   if (patch->solver) {
188:     for (i = 0; i < patch->npatch; ++i) {SNESDestroy((SNES *) &patch->solver[i]);}
189:     PetscFree(patch->solver);
190:   }
191:   return(0);
192: }

194: static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
195: {
196:   PC_PATCH      *patch = (PC_PATCH *) pc->data;

200:   PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate[i], patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR);
201:   return(0);
202: }

204: static PetscErrorCode SNESSetUp_Patch(SNES snes)
205: {
206:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
207:   DM             dm;
208:   Mat            dummy;
209:   Vec            F;
210:   PetscInt       n, N;

214:   SNESGetDM(snes, &dm);
215:   PCSetDM(patch->pc, dm);
216:   SNESGetFunction(snes, &F, NULL, NULL);
217:   VecGetLocalSize(F, &n);
218:   VecGetSize(F, &N);
219:   MatCreateShell(PetscObjectComm((PetscObject) snes), n, n, N, N, (void *) snes, &dummy);
220:   PCSetOperators(patch->pc, dummy, dummy);
221:   MatDestroy(&dummy);
222:   PCSetUp(patch->pc);
223:   /* allocate workspace */
224:   return(0);
225: }

227: static PetscErrorCode SNESReset_Patch(SNES snes)
228: {
229:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

233:   PCReset(patch->pc);
234:   return(0);
235: }

237: static PetscErrorCode SNESDestroy_Patch(SNES snes)
238: {
239:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

243:   SNESReset_Patch(snes);
244:   PCDestroy(&patch->pc);
245:   PetscFree(snes->data);
246:   return(0);
247: }

249: static PetscErrorCode SNESSetFromOptions_Patch(PetscOptionItems *PetscOptionsObject, SNES snes)
250: {
251:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
252:   const char    *prefix;

256:   PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix);
257:   PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix);
258:   PCSetFromOptions(patch->pc);
259:   return(0);
260: }

262: static PetscErrorCode SNESView_Patch(SNES snes,PetscViewer viewer)
263: {
264:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
265:   PetscBool      iascii;

269:   PetscObjectTypeCompare((PetscObject) viewer, PETSCVIEWERASCII, &iascii);
270:   if (iascii) {
271:     PetscViewerASCIIPrintf(viewer,"SNESPATCH\n");
272:   }
273:   PetscViewerASCIIPushTab(viewer);
274:   PCView(patch->pc, viewer);
275:   PetscViewerASCIIPopTab(viewer);
276:   return(0);
277: }

279: static PetscErrorCode SNESSolve_Patch(SNES snes)
280: {
281:   SNES_Patch *patch = (SNES_Patch *) snes->data;
282:   PC_PATCH   *pcpatch = (PC_PATCH *) patch->pc->data;
283:   SNESLineSearch ls;
284:   Vec rhs, update, state, residual;
285:   const PetscScalar *globalState  = NULL;
286:   PetscScalar       *localState   = NULL;
287:   PetscInt its = 0;
288:   PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;

292:   SNESGetSolution(snes, &state);
293:   SNESGetSolutionUpdate(snes, &update);
294:   SNESGetRhs(snes, &rhs);

296:   SNESGetFunction(snes, &residual, NULL, NULL);
297:   SNESGetLineSearch(snes, &ls);

299:   SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING);
300:   VecSet(update, 0.0);
301:   SNESComputeFunction(snes, state, residual);

303:   VecNorm(state, NORM_2, &xnorm);
304:   VecNorm(residual, NORM_2, &fnorm);
305:   snes->ttol = fnorm*snes->rtol;

307:   if (snes->ops->converged) {
308:     (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
309:   } else {
310:     SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,0);
311:   }
312:   SNESLogConvergenceHistory(snes, fnorm, 0); /* should we count lits from the patches? */
313:   SNESMonitor(snes, its, fnorm);

315:   /* The main solver loop */
316:   for (its = 0; its < snes->max_its; its++) {

318:     SNESSetIterationNumber(snes, its);

320:     /* Scatter state vector to overlapped vector on all patches.
321:        The vector pcpatch->localState is scattered to each patch
322:        in PCApply_PATCH_Nonlinear. */
323:     VecGetArrayRead(state, &globalState);
324:     VecGetArray(pcpatch->localState, &localState);
325:     PetscSFBcastBegin(pcpatch->defaultSF, MPIU_SCALAR, globalState, localState);
326:     PetscSFBcastEnd(pcpatch->defaultSF, MPIU_SCALAR, globalState, localState);
327:     VecRestoreArray(pcpatch->localState, &localState);
328:     VecRestoreArrayRead(state, &globalState);

330:     /* The looping over patches happens here */
331:     PCApply(patch->pc, rhs, update);

333:     /* Apply a line search. This will often be basic with
334:        damping = 1/(max number of patches a dof can be in),
335:        but not always */
336:     VecScale(update, -1.0);
337:     SNESLineSearchApply(ls, state, residual, &fnorm, update);

339:     VecNorm(state, NORM_2, &xnorm);
340:     VecNorm(update, NORM_2, &ynorm);

342:     if (snes->ops->converged) {
343:       (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
344:     } else {
345:       SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,0);
346:     }
347:     SNESLogConvergenceHistory(snes, fnorm, 0); /* FIXME: should we count lits? */
348:     SNESMonitor(snes, its, fnorm);
349:   }

351:   if (its == snes->max_its) { SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT); }
352:   return(0);
353: }

355: /*MC
356:   SNESPATCH - Solve a nonlinear problem by composing together many nonlinear solvers on patches

358:   Level: intermediate

360:   Concepts: composing solvers

362: .seealso:  SNESCreate(), SNESSetType(), SNESType (for list of available types), SNES,
363:            PCPATCH

365:    References:
366: .  1. - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015

368: M*/
369: PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
370: {
372:   SNES_Patch    *patch;
373:   PC_PATCH      *patchpc;

376:   PetscNewLog(snes, &patch);

378:   snes->ops->solve          = SNESSolve_Patch;
379:   snes->ops->setup          = SNESSetUp_Patch;
380:   snes->ops->reset          = SNESReset_Patch;
381:   snes->ops->destroy        = SNESDestroy_Patch;
382:   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
383:   snes->ops->view           = SNESView_Patch;

385:   snes->alwayscomputesfinalresidual = PETSC_FALSE;

387:   snes->data = (void *) patch;
388:   PCCreate(PetscObjectComm((PetscObject) snes), &patch->pc);
389:   PCSetType(patch->pc, PCPATCH);

391:   patchpc = (PC_PATCH*) patch->pc->data;
392:   patchpc->classname = "snes";
393:   patchpc->isNonlinear = PETSC_TRUE;

395:   patchpc->setupsolver   = PCSetUp_PATCH_Nonlinear;
396:   patchpc->applysolver   = PCApply_PATCH_Nonlinear;
397:   patchpc->resetsolver   = PCReset_PATCH_Nonlinear;
398:   patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
399:   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;

401:   return(0);
402: }

404: PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap,
405:                                             const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
406: {
407:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
409:   DM dm;

412:   SNESGetDM(snes, &dm);
413:   if (!dm) SETERRQ(PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES\n");
414:   PCSetDM(patch->pc, dm);
415:   PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes);
416:   return(0);
417: }

419: PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
420: {
421:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

425:   PCPatchSetComputeOperator(patch->pc, func, ctx);
426:   return(0);
427: }

429: PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
430: {
431:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

435:   PCPatchSetComputeFunction(patch->pc, func, ctx);
436:   return(0);
437: }

439: PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
440: {
441:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

445:   PCPatchSetConstructType(patch->pc, ctype, func, ctx);
446:   return(0);
447: }

449: PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
450: {
451:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

455:   PCPatchSetCellNumbering(patch->pc, cellNumbering);
456:   return(0);
457: }