Actual source code: snespatch.c
petsc-3.11.4 2019-09-28
1: /*
2: Defines a SNES that can consist of a collection of SNESes on patches of the domain
3: */
4: #include <petsc/private/snesimpl.h>
5: #include <petsc/private/pcpatchimpl.h> /* We need internal access to PCPatch right now, until that part is moved to Plex */
6: #include <petscsf.h>
8: typedef struct {
9: PC pc; /* The linear patch preconditioner */
10: } SNES_Patch;
12: static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
13: {
14: PC pc = (PC) ctx;
15: PC_PATCH *pcpatch = (PC_PATCH *) pc->data;
16: PetscInt pt, size, i;
17: const PetscInt *indices;
18: const PetscScalar *X;
19: PetscScalar *XWithAll;
24: /* scatter from x to patch->patchStateWithAll[pt] */
25: pt = pcpatch->currentPatch;
26: ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);
28: ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
29: VecGetArrayRead(x, &X);
30: VecGetArray(pcpatch->patchStateWithAll[pt], &XWithAll);
32: for (i = 0; i < size; ++i) {
33: XWithAll[indices[i]] = X[i];
34: }
36: VecRestoreArray(pcpatch->patchStateWithAll[pt], &XWithAll);
37: VecRestoreArrayRead(x, &X);
38: ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
40: PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll[pt], F, pt);
41: return(0);
42: }
44: static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
45: {
46: PC pc = (PC) ctx;
47: PC_PATCH *pcpatch = (PC_PATCH *) pc->data;
48: PetscInt pt, size, i;
49: const PetscInt *indices;
50: const PetscScalar *X;
51: PetscScalar *XWithAll;
55: /* scatter from x to patch->patchStateWithAll[pt] */
56: pt = pcpatch->currentPatch;
57: ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);
59: ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
60: VecGetArrayRead(x, &X);
61: VecGetArray(pcpatch->patchStateWithAll[pt], &XWithAll);
63: for (i = 0; i < size; ++i) {
64: XWithAll[indices[i]] = X[i];
65: }
67: VecRestoreArray(pcpatch->patchStateWithAll[pt], &XWithAll);
68: VecRestoreArrayRead(x, &X);
69: ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
71: PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll[pt], M, pcpatch->currentPatch, PETSC_FALSE);
72: return(0);
73: }
75: static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
76: {
77: PC_PATCH *patch = (PC_PATCH *) pc->data;
78: const char *prefix;
79: PetscInt i, pStart, dof;
83: if (!pc->setupcalled) {
84: PetscMalloc1(patch->npatch, &patch->solver);
85: PCGetOptionsPrefix(pc, &prefix);
86: PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
87: for (i = 0; i < patch->npatch; ++i) {
88: SNES snes;
89: KSP subksp;
91: SNESCreate(PETSC_COMM_SELF, &snes);
92: SNESSetOptionsPrefix(snes, prefix);
93: SNESAppendOptionsPrefix(snes, "sub_");
94: PetscObjectIncrementTabLevel((PetscObject) snes, (PetscObject) pc, 2);
95: SNESGetKSP(snes, &subksp);
96: PetscObjectIncrementTabLevel((PetscObject) subksp, (PetscObject) pc, 2);
97: PetscLogObjectParent((PetscObject) pc, (PetscObject) snes);
98: patch->solver[i] = (PetscObject) snes;
99: }
101: PetscMalloc1(patch->npatch, &patch->patchResidual);
102: PetscMalloc1(patch->npatch, &patch->patchState);
103: PetscMalloc1(patch->npatch, &patch->patchStateWithAll);
104: for (i = 0; i < patch->npatch; ++i) {
105: VecDuplicate(patch->patchRHS[i], &patch->patchResidual[i]);
106: VecDuplicate(patch->patchUpdate[i], &patch->patchState[i]);
108: PetscSectionGetDof(patch->gtolCountsWithAll, i+pStart, &dof);
109: VecCreateSeq(PETSC_COMM_SELF, dof, &patch->patchStateWithAll[i]);
110: VecSetUp(patch->patchStateWithAll[i]);
111: }
112: VecDuplicate(patch->localUpdate, &patch->localState);
113: }
114: for (i = 0; i < patch->npatch; ++i) {
115: SNES snes = (SNES) patch->solver[i];
117: SNESSetFunction(snes, patch->patchResidual[i], SNESPatchComputeResidual_Private, pc);
118: SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc);
119: }
120: if (!pc->setupcalled && patch->optionsSet) for (i = 0; i < patch->npatch; ++i) {SNESSetFromOptions((SNES) patch->solver[i]);}
121: return(0);
122: }
124: static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
125: {
126: PC_PATCH *patch = (PC_PATCH *) pc->data;
127: PetscInt pStart;
131: patch->currentPatch = i;
132: PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0);
134: /* Scatter the overlapped global state to our patch state vector */
135: PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
136: PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchState[i], INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR);
137: PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchStateWithAll[i], INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL);
139: /* Set initial guess to be current state*/
140: VecCopy(patch->patchState[i], patchUpdate);
141: /* Solve for new state */
142: SNESSolve((SNES) patch->solver[i], patchRHS, patchUpdate);
143: /* To compute update, subtract off previous state */
144: VecAXPY(patchUpdate, -1.0, patch->patchState[i]);
146: PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0);
147: return(0);
148: }
150: static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
151: {
152: PC_PATCH *patch = (PC_PATCH *) pc->data;
153: PetscInt i;
157: if (patch->solver) {
158: for (i = 0; i < patch->npatch; ++i) {SNESReset((SNES) patch->solver[i]);}
159: }
161: if (patch->patchResidual) {
162: for (i = 0; i < patch->npatch; ++i) {VecDestroy(&patch->patchResidual[i]);}
163: PetscFree(patch->patchResidual);
164: }
166: if (patch->patchState) {
167: for (i = 0; i < patch->npatch; ++i) {VecDestroy(&patch->patchState[i]);}
168: PetscFree(patch->patchState);
169: }
171: if (patch->patchStateWithAll) {
172: for (i = 0; i < patch->npatch; ++i) {VecDestroy(&patch->patchStateWithAll[i]);}
173: PetscFree(patch->patchStateWithAll);
174: }
176: VecDestroy(&patch->localState);
177: return(0);
178: }
180: static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
181: {
182: PC_PATCH *patch = (PC_PATCH *) pc->data;
183: PetscInt i;
187: if (patch->solver) {
188: for (i = 0; i < patch->npatch; ++i) {SNESDestroy((SNES *) &patch->solver[i]);}
189: PetscFree(patch->solver);
190: }
191: return(0);
192: }
194: static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
195: {
196: PC_PATCH *patch = (PC_PATCH *) pc->data;
200: PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate[i], patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR);
201: return(0);
202: }
204: static PetscErrorCode SNESSetUp_Patch(SNES snes)
205: {
206: SNES_Patch *patch = (SNES_Patch *) snes->data;
207: DM dm;
208: Mat dummy;
209: Vec F;
210: PetscInt n, N;
214: SNESGetDM(snes, &dm);
215: PCSetDM(patch->pc, dm);
216: SNESGetFunction(snes, &F, NULL, NULL);
217: VecGetLocalSize(F, &n);
218: VecGetSize(F, &N);
219: MatCreateShell(PetscObjectComm((PetscObject) snes), n, n, N, N, (void *) snes, &dummy);
220: PCSetOperators(patch->pc, dummy, dummy);
221: MatDestroy(&dummy);
222: PCSetUp(patch->pc);
223: /* allocate workspace */
224: return(0);
225: }
227: static PetscErrorCode SNESReset_Patch(SNES snes)
228: {
229: SNES_Patch *patch = (SNES_Patch *) snes->data;
233: PCReset(patch->pc);
234: return(0);
235: }
237: static PetscErrorCode SNESDestroy_Patch(SNES snes)
238: {
239: SNES_Patch *patch = (SNES_Patch *) snes->data;
243: SNESReset_Patch(snes);
244: PCDestroy(&patch->pc);
245: PetscFree(snes->data);
246: return(0);
247: }
249: static PetscErrorCode SNESSetFromOptions_Patch(PetscOptionItems *PetscOptionsObject, SNES snes)
250: {
251: SNES_Patch *patch = (SNES_Patch *) snes->data;
252: const char *prefix;
256: PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix);
257: PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix);
258: PCSetFromOptions(patch->pc);
259: return(0);
260: }
262: static PetscErrorCode SNESView_Patch(SNES snes,PetscViewer viewer)
263: {
264: SNES_Patch *patch = (SNES_Patch *) snes->data;
265: PetscBool iascii;
269: PetscObjectTypeCompare((PetscObject) viewer, PETSCVIEWERASCII, &iascii);
270: if (iascii) {
271: PetscViewerASCIIPrintf(viewer,"SNESPATCH\n");
272: }
273: PetscViewerASCIIPushTab(viewer);
274: PCView(patch->pc, viewer);
275: PetscViewerASCIIPopTab(viewer);
276: return(0);
277: }
279: static PetscErrorCode SNESSolve_Patch(SNES snes)
280: {
281: SNES_Patch *patch = (SNES_Patch *) snes->data;
282: PC_PATCH *pcpatch = (PC_PATCH *) patch->pc->data;
283: SNESLineSearch ls;
284: Vec rhs, update, state, residual;
285: const PetscScalar *globalState = NULL;
286: PetscScalar *localState = NULL;
287: PetscInt its = 0;
288: PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
292: SNESGetSolution(snes, &state);
293: SNESGetSolutionUpdate(snes, &update);
294: SNESGetRhs(snes, &rhs);
296: SNESGetFunction(snes, &residual, NULL, NULL);
297: SNESGetLineSearch(snes, &ls);
299: SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING);
300: VecSet(update, 0.0);
301: SNESComputeFunction(snes, state, residual);
303: VecNorm(state, NORM_2, &xnorm);
304: VecNorm(residual, NORM_2, &fnorm);
305: snes->ttol = fnorm*snes->rtol;
307: if (snes->ops->converged) {
308: (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
309: } else {
310: SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,0);
311: }
312: SNESLogConvergenceHistory(snes, fnorm, 0); /* should we count lits from the patches? */
313: SNESMonitor(snes, its, fnorm);
315: /* The main solver loop */
316: for (its = 0; its < snes->max_its; its++) {
318: SNESSetIterationNumber(snes, its);
320: /* Scatter state vector to overlapped vector on all patches.
321: The vector pcpatch->localState is scattered to each patch
322: in PCApply_PATCH_Nonlinear. */
323: VecGetArrayRead(state, &globalState);
324: VecGetArray(pcpatch->localState, &localState);
325: PetscSFBcastBegin(pcpatch->defaultSF, MPIU_SCALAR, globalState, localState);
326: PetscSFBcastEnd(pcpatch->defaultSF, MPIU_SCALAR, globalState, localState);
327: VecRestoreArray(pcpatch->localState, &localState);
328: VecRestoreArrayRead(state, &globalState);
330: /* The looping over patches happens here */
331: PCApply(patch->pc, rhs, update);
333: /* Apply a line search. This will often be basic with
334: damping = 1/(max number of patches a dof can be in),
335: but not always */
336: VecScale(update, -1.0);
337: SNESLineSearchApply(ls, state, residual, &fnorm, update);
339: VecNorm(state, NORM_2, &xnorm);
340: VecNorm(update, NORM_2, &ynorm);
342: if (snes->ops->converged) {
343: (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
344: } else {
345: SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,0);
346: }
347: SNESLogConvergenceHistory(snes, fnorm, 0); /* FIXME: should we count lits? */
348: SNESMonitor(snes, its, fnorm);
349: }
351: if (its == snes->max_its) { SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT); }
352: return(0);
353: }
355: /*MC
356: SNESPATCH - Solve a nonlinear problem by composing together many nonlinear solvers on patches
358: Level: intermediate
360: Concepts: composing solvers
362: .seealso: SNESCreate(), SNESSetType(), SNESType (for list of available types), SNES,
363: PCPATCH
365: References:
366: . 1. - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015
368: M*/
369: PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
370: {
372: SNES_Patch *patch;
373: PC_PATCH *patchpc;
376: PetscNewLog(snes, &patch);
378: snes->ops->solve = SNESSolve_Patch;
379: snes->ops->setup = SNESSetUp_Patch;
380: snes->ops->reset = SNESReset_Patch;
381: snes->ops->destroy = SNESDestroy_Patch;
382: snes->ops->setfromoptions = SNESSetFromOptions_Patch;
383: snes->ops->view = SNESView_Patch;
385: snes->alwayscomputesfinalresidual = PETSC_FALSE;
387: snes->data = (void *) patch;
388: PCCreate(PetscObjectComm((PetscObject) snes), &patch->pc);
389: PCSetType(patch->pc, PCPATCH);
391: patchpc = (PC_PATCH*) patch->pc->data;
392: patchpc->classname = "snes";
393: patchpc->isNonlinear = PETSC_TRUE;
395: patchpc->setupsolver = PCSetUp_PATCH_Nonlinear;
396: patchpc->applysolver = PCApply_PATCH_Nonlinear;
397: patchpc->resetsolver = PCReset_PATCH_Nonlinear;
398: patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
399: patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
401: return(0);
402: }
404: PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap,
405: const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
406: {
407: SNES_Patch *patch = (SNES_Patch *) snes->data;
409: DM dm;
412: SNESGetDM(snes, &dm);
413: if (!dm) SETERRQ(PetscObjectComm((PetscObject)snes), PETSC_ERR_ARG_WRONGSTATE, "DM not yet set on patch SNES\n");
414: PCSetDM(patch->pc, dm);
415: PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes);
416: return(0);
417: }
419: PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
420: {
421: SNES_Patch *patch = (SNES_Patch *) snes->data;
425: PCPatchSetComputeOperator(patch->pc, func, ctx);
426: return(0);
427: }
429: PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
430: {
431: SNES_Patch *patch = (SNES_Patch *) snes->data;
435: PCPatchSetComputeFunction(patch->pc, func, ctx);
436: return(0);
437: }
439: PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
440: {
441: SNES_Patch *patch = (SNES_Patch *) snes->data;
445: PCPatchSetConstructType(patch->pc, ctype, func, ctx);
446: return(0);
447: }
449: PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
450: {
451: SNES_Patch *patch = (SNES_Patch *) snes->data;
455: PCPatchSetCellNumbering(patch->pc, cellNumbering);
456: return(0);
457: }