Actual source code: snespatch.c

  1: /*
  2:       Defines a SNES that can consist of a collection of SNESes on patches of the domain
  3: */
  4: #include <petsc/private/vecimpl.h>
  5: #include <petsc/private/snesimpl.h>
  6: #include <petsc/private/pcpatchimpl.h>
  7: #include <petscsf.h>
  8: #include <petscsection.h>

 10: typedef struct {
 11:   PC pc; /* The linear patch preconditioner */
 12: } SNES_Patch;

 14: static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
 15: {
 16:   PC                pc      = (PC) ctx;
 17:   PC_PATCH          *pcpatch = (PC_PATCH *) pc->data;
 18:   PetscInt          pt, size, i;
 19:   const PetscInt    *indices;
 20:   const PetscScalar *X;
 21:   PetscScalar       *XWithAll;


 24:   /* scatter from x to patch->patchStateWithAll[pt] */
 25:   pt = pcpatch->currentPatch;
 26:   ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);

 28:   ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
 29:   VecGetArrayRead(x, &X);
 30:   VecGetArray(pcpatch->patchStateWithAll, &XWithAll);

 32:   for (i = 0; i < size; ++i) {
 33:     XWithAll[indices[i]] = X[i];
 34:   }

 36:   VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll);
 37:   VecRestoreArrayRead(x, &X);
 38:   ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);

 40:   PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt);
 41:   return 0;
 42: }

 44: static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
 45: {
 46:   PC                pc      = (PC) ctx;
 47:   PC_PATCH          *pcpatch = (PC_PATCH *) pc->data;
 48:   PetscInt          pt, size, i;
 49:   const PetscInt    *indices;
 50:   const PetscScalar *X;
 51:   PetscScalar       *XWithAll;

 53:   /* scatter from x to patch->patchStateWithAll[pt] */
 54:   pt = pcpatch->currentPatch;
 55:   ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);

 57:   ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
 58:   VecGetArrayRead(x, &X);
 59:   VecGetArray(pcpatch->patchStateWithAll, &XWithAll);

 61:   for (i = 0; i < size; ++i) {
 62:     XWithAll[indices[i]] = X[i];
 63:   }

 65:   VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll);
 66:   VecRestoreArrayRead(x, &X);
 67:   ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);

 69:   PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE);
 70:   return 0;
 71: }

 73: static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
 74: {
 75:   PC_PATCH       *patch = (PC_PATCH *) pc->data;
 76:   const char     *prefix;
 77:   PetscInt       i, pStart, dof, maxDof = -1;

 79:   if (!pc->setupcalled) {
 80:     PetscMalloc1(patch->npatch, &patch->solver);
 81:     PCGetOptionsPrefix(pc, &prefix);
 82:     PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
 83:     for (i = 0; i < patch->npatch; ++i) {
 84:       SNES snes;

 86:       SNESCreate(PETSC_COMM_SELF, &snes);
 87:       SNESSetOptionsPrefix(snes, prefix);
 88:       SNESAppendOptionsPrefix(snes, "sub_");
 89:       PetscObjectIncrementTabLevel((PetscObject) snes, (PetscObject) pc, 2);
 90:       PetscLogObjectParent((PetscObject) pc, (PetscObject) snes);
 91:       patch->solver[i] = (PetscObject) snes;

 93:       PetscSectionGetDof(patch->gtolCountsWithAll, i+pStart, &dof);
 94:       maxDof = PetscMax(maxDof, dof);
 95:     }
 96:     VecDuplicate(patch->localUpdate, &patch->localState);
 97:     VecDuplicate(patch->patchRHS, &patch->patchResidual);
 98:     VecDuplicate(patch->patchUpdate, &patch->patchState);

100:     VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll);
101:     VecSetUp(patch->patchStateWithAll);
102:   }
103:   for (i = 0; i < patch->npatch; ++i) {
104:     SNES snes = (SNES) patch->solver[i];

106:     SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc);
107:     SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc);
108:   }
109:   if (!pc->setupcalled && patch->optionsSet) for (i = 0; i < patch->npatch; ++i) SNESSetFromOptions((SNES) patch->solver[i]);
110:   return 0;
111: }

113: static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
114: {
115:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
116:   PetscInt       pStart, n;

118:   patch->currentPatch = i;
119:   PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0);

121:   /* Scatter the overlapped global state to our patch state vector */
122:   PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
123:   PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR);
124:   PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL);

126:   MatGetLocalSize(patch->mat[i], NULL, &n);
127:   patch->patchState->map->n = n;
128:   patch->patchState->map->N = n;
129:   patchUpdate->map->n = n;
130:   patchUpdate->map->N = n;
131:   patchRHS->map->n = n;
132:   patchRHS->map->N = n;
133:   /* Set initial guess to be current state*/
134:   VecCopy(patch->patchState, patchUpdate);
135:   /* Solve for new state */
136:   SNESSolve((SNES) patch->solver[i], patchRHS, patchUpdate);
137:   /* To compute update, subtract off previous state */
138:   VecAXPY(patchUpdate, -1.0, patch->patchState);

140:   PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0);
141:   return 0;
142: }

144: static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
145: {
146:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
147:   PetscInt       i;

149:   if (patch->solver) {
150:     for (i = 0; i < patch->npatch; ++i) SNESReset((SNES) patch->solver[i]);
151:   }

153:   VecDestroy(&patch->patchResidual);
154:   VecDestroy(&patch->patchState);
155:   VecDestroy(&patch->patchStateWithAll);

157:   VecDestroy(&patch->localState);
158:   return 0;
159: }

161: static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
162: {
163:   PC_PATCH      *patch = (PC_PATCH *) pc->data;
164:   PetscInt       i;

166:   if (patch->solver) {
167:     for (i = 0; i < patch->npatch; ++i) SNESDestroy((SNES *) &patch->solver[i]);
168:     PetscFree(patch->solver);
169:   }
170:   return 0;
171: }

173: static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
174: {
175:   PC_PATCH      *patch = (PC_PATCH *) pc->data;

177:   PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR);
178:   return 0;
179: }

181: static PetscErrorCode SNESSetUp_Patch(SNES snes)
182: {
183:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
184:   DM             dm;
185:   Mat            dummy;
186:   Vec            F;
187:   PetscInt       n, N;

189:   SNESGetDM(snes, &dm);
190:   PCSetDM(patch->pc, dm);
191:   SNESGetFunction(snes, &F, NULL, NULL);
192:   VecGetLocalSize(F, &n);
193:   VecGetSize(F, &N);
194:   MatCreateShell(PetscObjectComm((PetscObject) snes), n, n, N, N, (void *) snes, &dummy);
195:   PCSetOperators(patch->pc, dummy, dummy);
196:   MatDestroy(&dummy);
197:   PCSetUp(patch->pc);
198:   /* allocate workspace */
199:   return 0;
200: }

202: static PetscErrorCode SNESReset_Patch(SNES snes)
203: {
204:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

206:   PCReset(patch->pc);
207:   return 0;
208: }

210: static PetscErrorCode SNESDestroy_Patch(SNES snes)
211: {
212:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

214:   SNESReset_Patch(snes);
215:   PCDestroy(&patch->pc);
216:   PetscFree(snes->data);
217:   return 0;
218: }

220: static PetscErrorCode SNESSetFromOptions_Patch(PetscOptionItems *PetscOptionsObject, SNES snes)
221: {
222:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
223:   const char    *prefix;

225:   PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix);
226:   PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix);
227:   PCSetFromOptions(patch->pc);
228:   return 0;
229: }

231: static PetscErrorCode SNESView_Patch(SNES snes,PetscViewer viewer)
232: {
233:   SNES_Patch    *patch = (SNES_Patch *) snes->data;
234:   PetscBool      iascii;

236:   PetscObjectTypeCompare((PetscObject) viewer, PETSCVIEWERASCII, &iascii);
237:   if (iascii) {
238:     PetscViewerASCIIPrintf(viewer,"SNESPATCH\n");
239:   }
240:   PetscViewerASCIIPushTab(viewer);
241:   PCView(patch->pc, viewer);
242:   PetscViewerASCIIPopTab(viewer);
243:   return 0;
244: }

246: static PetscErrorCode SNESSolve_Patch(SNES snes)
247: {
248:   SNES_Patch        *patch = (SNES_Patch *) snes->data;
249:   PC_PATCH          *pcpatch = (PC_PATCH *) patch->pc->data;
250:   SNESLineSearch    ls;
251:   Vec               rhs, update, state, residual;
252:   const PetscScalar *globalState  = NULL;
253:   PetscScalar       *localState   = NULL;
254:   PetscInt          its = 0;
255:   PetscReal         xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;

257:   SNESGetSolution(snes, &state);
258:   SNESGetSolutionUpdate(snes, &update);
259:   SNESGetRhs(snes, &rhs);

261:   SNESGetFunction(snes, &residual, NULL, NULL);
262:   SNESGetLineSearch(snes, &ls);

264:   SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING);
265:   VecSet(update, 0.0);
266:   SNESComputeFunction(snes, state, residual);

268:   VecNorm(state, NORM_2, &xnorm);
269:   VecNorm(residual, NORM_2, &fnorm);
270:   snes->ttol = fnorm*snes->rtol;

272:   if (snes->ops->converged) {
273:     (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
274:   } else {
275:     SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,NULL);
276:   }
277:   SNESLogConvergenceHistory(snes, fnorm, 0); /* should we count lits from the patches? */
278:   SNESMonitor(snes, its, fnorm);

280:   /* The main solver loop */
281:   for (its = 0; its < snes->max_its; its++) {

283:     SNESSetIterationNumber(snes, its);

285:     /* Scatter state vector to overlapped vector on all patches.
286:        The vector pcpatch->localState is scattered to each patch
287:        in PCApply_PATCH_Nonlinear. */
288:     VecGetArrayRead(state, &globalState);
289:     VecGetArray(pcpatch->localState, &localState);
290:     PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState,MPI_REPLACE);
291:     PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState,MPI_REPLACE);
292:     VecRestoreArray(pcpatch->localState, &localState);
293:     VecRestoreArrayRead(state, &globalState);

295:     /* The looping over patches happens here */
296:     PCApply(patch->pc, rhs, update);

298:     /* Apply a line search. This will often be basic with
299:        damping = 1/(max number of patches a dof can be in),
300:        but not always */
301:     VecScale(update, -1.0);
302:     SNESLineSearchApply(ls, state, residual, &fnorm, update);

304:     VecNorm(state, NORM_2, &xnorm);
305:     VecNorm(update, NORM_2, &ynorm);

307:     if (snes->ops->converged) {
308:       (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
309:     } else {
310:       SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,NULL);
311:     }
312:     SNESLogConvergenceHistory(snes, fnorm, 0); /* FIXME: should we count lits? */
313:     SNESMonitor(snes, its, fnorm);
314:   }

316:   if (its == snes->max_its) SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT);
317:   return 0;
318: }

320: /*MC
321:   SNESPATCH - Solve a nonlinear problem by composing together many nonlinear solvers on patches

323:   Level: intermediate

325: .seealso:  SNESCreate(), SNESSetType(), SNESType (for list of available types), SNES,
326:            PCPATCH

328:    References:
329: .  * - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015

331: M*/
332: PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
333: {
334:   SNES_Patch     *patch;
335:   PC_PATCH       *patchpc;
336:   SNESLineSearch linesearch;

338:   PetscNewLog(snes, &patch);

340:   snes->ops->solve          = SNESSolve_Patch;
341:   snes->ops->setup          = SNESSetUp_Patch;
342:   snes->ops->reset          = SNESReset_Patch;
343:   snes->ops->destroy        = SNESDestroy_Patch;
344:   snes->ops->setfromoptions = SNESSetFromOptions_Patch;
345:   snes->ops->view           = SNESView_Patch;

347:   SNESGetLineSearch(snes,&linesearch);
348:   if (!((PetscObject)linesearch)->type_name) {
349:     SNESLineSearchSetType(linesearch,SNESLINESEARCHBASIC);
350:   }
351:   snes->usesksp        = PETSC_FALSE;

353:   snes->alwayscomputesfinalresidual = PETSC_FALSE;

355:   snes->data = (void *) patch;
356:   PCCreate(PetscObjectComm((PetscObject) snes), &patch->pc);
357:   PCSetType(patch->pc, PCPATCH);

359:   patchpc = (PC_PATCH*) patch->pc->data;
360:   patchpc->classname = "snes";
361:   patchpc->isNonlinear = PETSC_TRUE;

363:   patchpc->setupsolver   = PCSetUp_PATCH_Nonlinear;
364:   patchpc->applysolver   = PCApply_PATCH_Nonlinear;
365:   patchpc->resetsolver   = PCReset_PATCH_Nonlinear;
366:   patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
367:   patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;

369:   return 0;
370: }

372: PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap,
373:                                             const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
374: {
375:   SNES_Patch     *patch = (SNES_Patch *) snes->data;
376:   DM             dm;

378:   SNESGetDM(snes, &dm);
380:   PCSetDM(patch->pc, dm);
381:   PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes);
382:   return 0;
383: }

385: PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
386: {
387:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

389:   PCPatchSetComputeOperator(patch->pc, func, ctx);
390:   return 0;
391: }

393: PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
394: {
395:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

397:   PCPatchSetComputeFunction(patch->pc, func, ctx);
398:   return 0;
399: }

401: PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
402: {
403:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

405:   PCPatchSetConstructType(patch->pc, ctype, func, ctx);
406:   return 0;
407: }

409: PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
410: {
411:   SNES_Patch    *patch = (SNES_Patch *) snes->data;

413:   PCPatchSetCellNumbering(patch->pc, cellNumbering);
414:   return 0;
415: }