Actual source code: snespatch.c
1: /*
2: Defines a SNES that can consist of a collection of SNESes on patches of the domain
3: */
4: #include <petsc/private/vecimpl.h>
5: #include <petsc/private/snesimpl.h>
6: #include <petsc/private/pcpatchimpl.h>
7: #include <petscsf.h>
8: #include <petscsection.h>
10: typedef struct {
11: PC pc; /* The linear patch preconditioner */
12: } SNES_Patch;
14: static PetscErrorCode SNESPatchComputeResidual_Private(SNES snes, Vec x, Vec F, void *ctx)
15: {
16: PC pc = (PC) ctx;
17: PC_PATCH *pcpatch = (PC_PATCH *) pc->data;
18: PetscInt pt, size, i;
19: const PetscInt *indices;
20: const PetscScalar *X;
21: PetscScalar *XWithAll;
24: /* scatter from x to patch->patchStateWithAll[pt] */
25: pt = pcpatch->currentPatch;
26: ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);
28: ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
29: VecGetArrayRead(x, &X);
30: VecGetArray(pcpatch->patchStateWithAll, &XWithAll);
32: for (i = 0; i < size; ++i) {
33: XWithAll[indices[i]] = X[i];
34: }
36: VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll);
37: VecRestoreArrayRead(x, &X);
38: ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
40: PCPatchComputeFunction_Internal(pc, pcpatch->patchStateWithAll, F, pt);
41: return 0;
42: }
44: static PetscErrorCode SNESPatchComputeJacobian_Private(SNES snes, Vec x, Mat J, Mat M, void *ctx)
45: {
46: PC pc = (PC) ctx;
47: PC_PATCH *pcpatch = (PC_PATCH *) pc->data;
48: PetscInt pt, size, i;
49: const PetscInt *indices;
50: const PetscScalar *X;
51: PetscScalar *XWithAll;
53: /* scatter from x to patch->patchStateWithAll[pt] */
54: pt = pcpatch->currentPatch;
55: ISGetSize(pcpatch->dofMappingWithoutToWithAll[pt], &size);
57: ISGetIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
58: VecGetArrayRead(x, &X);
59: VecGetArray(pcpatch->patchStateWithAll, &XWithAll);
61: for (i = 0; i < size; ++i) {
62: XWithAll[indices[i]] = X[i];
63: }
65: VecRestoreArray(pcpatch->patchStateWithAll, &XWithAll);
66: VecRestoreArrayRead(x, &X);
67: ISRestoreIndices(pcpatch->dofMappingWithoutToWithAll[pt], &indices);
69: PCPatchComputeOperator_Internal(pc, pcpatch->patchStateWithAll, M, pcpatch->currentPatch, PETSC_FALSE);
70: return 0;
71: }
73: static PetscErrorCode PCSetUp_PATCH_Nonlinear(PC pc)
74: {
75: PC_PATCH *patch = (PC_PATCH *) pc->data;
76: const char *prefix;
77: PetscInt i, pStart, dof, maxDof = -1;
79: if (!pc->setupcalled) {
80: PetscMalloc1(patch->npatch, &patch->solver);
81: PCGetOptionsPrefix(pc, &prefix);
82: PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
83: for (i = 0; i < patch->npatch; ++i) {
84: SNES snes;
86: SNESCreate(PETSC_COMM_SELF, &snes);
87: SNESSetOptionsPrefix(snes, prefix);
88: SNESAppendOptionsPrefix(snes, "sub_");
89: PetscObjectIncrementTabLevel((PetscObject) snes, (PetscObject) pc, 2);
90: PetscLogObjectParent((PetscObject) pc, (PetscObject) snes);
91: patch->solver[i] = (PetscObject) snes;
93: PetscSectionGetDof(patch->gtolCountsWithAll, i+pStart, &dof);
94: maxDof = PetscMax(maxDof, dof);
95: }
96: VecDuplicate(patch->localUpdate, &patch->localState);
97: VecDuplicate(patch->patchRHS, &patch->patchResidual);
98: VecDuplicate(patch->patchUpdate, &patch->patchState);
100: VecCreateSeq(PETSC_COMM_SELF, maxDof, &patch->patchStateWithAll);
101: VecSetUp(patch->patchStateWithAll);
102: }
103: for (i = 0; i < patch->npatch; ++i) {
104: SNES snes = (SNES) patch->solver[i];
106: SNESSetFunction(snes, patch->patchResidual, SNESPatchComputeResidual_Private, pc);
107: SNESSetJacobian(snes, patch->mat[i], patch->mat[i], SNESPatchComputeJacobian_Private, pc);
108: }
109: if (!pc->setupcalled && patch->optionsSet) for (i = 0; i < patch->npatch; ++i) SNESSetFromOptions((SNES) patch->solver[i]);
110: return 0;
111: }
113: static PetscErrorCode PCApply_PATCH_Nonlinear(PC pc, PetscInt i, Vec patchRHS, Vec patchUpdate)
114: {
115: PC_PATCH *patch = (PC_PATCH *) pc->data;
116: PetscInt pStart, n;
118: patch->currentPatch = i;
119: PetscLogEventBegin(PC_Patch_Solve, pc, 0, 0, 0);
121: /* Scatter the overlapped global state to our patch state vector */
122: PetscSectionGetChart(patch->gtolCounts, &pStart, NULL);
123: PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchState, INSERT_VALUES, SCATTER_FORWARD, SCATTER_INTERIOR);
124: PCPatch_ScatterLocal_Private(pc, i+pStart, patch->localState, patch->patchStateWithAll, INSERT_VALUES, SCATTER_FORWARD, SCATTER_WITHALL);
126: MatGetLocalSize(patch->mat[i], NULL, &n);
127: patch->patchState->map->n = n;
128: patch->patchState->map->N = n;
129: patchUpdate->map->n = n;
130: patchUpdate->map->N = n;
131: patchRHS->map->n = n;
132: patchRHS->map->N = n;
133: /* Set initial guess to be current state*/
134: VecCopy(patch->patchState, patchUpdate);
135: /* Solve for new state */
136: SNESSolve((SNES) patch->solver[i], patchRHS, patchUpdate);
137: /* To compute update, subtract off previous state */
138: VecAXPY(patchUpdate, -1.0, patch->patchState);
140: PetscLogEventEnd(PC_Patch_Solve, pc, 0, 0, 0);
141: return 0;
142: }
144: static PetscErrorCode PCReset_PATCH_Nonlinear(PC pc)
145: {
146: PC_PATCH *patch = (PC_PATCH *) pc->data;
147: PetscInt i;
149: if (patch->solver) {
150: for (i = 0; i < patch->npatch; ++i) SNESReset((SNES) patch->solver[i]);
151: }
153: VecDestroy(&patch->patchResidual);
154: VecDestroy(&patch->patchState);
155: VecDestroy(&patch->patchStateWithAll);
157: VecDestroy(&patch->localState);
158: return 0;
159: }
161: static PetscErrorCode PCDestroy_PATCH_Nonlinear(PC pc)
162: {
163: PC_PATCH *patch = (PC_PATCH *) pc->data;
164: PetscInt i;
166: if (patch->solver) {
167: for (i = 0; i < patch->npatch; ++i) SNESDestroy((SNES *) &patch->solver[i]);
168: PetscFree(patch->solver);
169: }
170: return 0;
171: }
173: static PetscErrorCode PCUpdateMultiplicative_PATCH_Nonlinear(PC pc, PetscInt i, PetscInt pStart)
174: {
175: PC_PATCH *patch = (PC_PATCH *) pc->data;
177: PCPatch_ScatterLocal_Private(pc, i + pStart, patch->patchUpdate, patch->localState, ADD_VALUES, SCATTER_REVERSE, SCATTER_INTERIOR);
178: return 0;
179: }
181: static PetscErrorCode SNESSetUp_Patch(SNES snes)
182: {
183: SNES_Patch *patch = (SNES_Patch *) snes->data;
184: DM dm;
185: Mat dummy;
186: Vec F;
187: PetscInt n, N;
189: SNESGetDM(snes, &dm);
190: PCSetDM(patch->pc, dm);
191: SNESGetFunction(snes, &F, NULL, NULL);
192: VecGetLocalSize(F, &n);
193: VecGetSize(F, &N);
194: MatCreateShell(PetscObjectComm((PetscObject) snes), n, n, N, N, (void *) snes, &dummy);
195: PCSetOperators(patch->pc, dummy, dummy);
196: MatDestroy(&dummy);
197: PCSetUp(patch->pc);
198: /* allocate workspace */
199: return 0;
200: }
202: static PetscErrorCode SNESReset_Patch(SNES snes)
203: {
204: SNES_Patch *patch = (SNES_Patch *) snes->data;
206: PCReset(patch->pc);
207: return 0;
208: }
210: static PetscErrorCode SNESDestroy_Patch(SNES snes)
211: {
212: SNES_Patch *patch = (SNES_Patch *) snes->data;
214: SNESReset_Patch(snes);
215: PCDestroy(&patch->pc);
216: PetscFree(snes->data);
217: return 0;
218: }
220: static PetscErrorCode SNESSetFromOptions_Patch(PetscOptionItems *PetscOptionsObject, SNES snes)
221: {
222: SNES_Patch *patch = (SNES_Patch *) snes->data;
223: const char *prefix;
225: PetscObjectGetOptionsPrefix((PetscObject)snes, &prefix);
226: PetscObjectSetOptionsPrefix((PetscObject)patch->pc, prefix);
227: PCSetFromOptions(patch->pc);
228: return 0;
229: }
231: static PetscErrorCode SNESView_Patch(SNES snes,PetscViewer viewer)
232: {
233: SNES_Patch *patch = (SNES_Patch *) snes->data;
234: PetscBool iascii;
236: PetscObjectTypeCompare((PetscObject) viewer, PETSCVIEWERASCII, &iascii);
237: if (iascii) {
238: PetscViewerASCIIPrintf(viewer,"SNESPATCH\n");
239: }
240: PetscViewerASCIIPushTab(viewer);
241: PCView(patch->pc, viewer);
242: PetscViewerASCIIPopTab(viewer);
243: return 0;
244: }
246: static PetscErrorCode SNESSolve_Patch(SNES snes)
247: {
248: SNES_Patch *patch = (SNES_Patch *) snes->data;
249: PC_PATCH *pcpatch = (PC_PATCH *) patch->pc->data;
250: SNESLineSearch ls;
251: Vec rhs, update, state, residual;
252: const PetscScalar *globalState = NULL;
253: PetscScalar *localState = NULL;
254: PetscInt its = 0;
255: PetscReal xnorm = 0.0, ynorm = 0.0, fnorm = 0.0;
257: SNESGetSolution(snes, &state);
258: SNESGetSolutionUpdate(snes, &update);
259: SNESGetRhs(snes, &rhs);
261: SNESGetFunction(snes, &residual, NULL, NULL);
262: SNESGetLineSearch(snes, &ls);
264: SNESSetConvergedReason(snes, SNES_CONVERGED_ITERATING);
265: VecSet(update, 0.0);
266: SNESComputeFunction(snes, state, residual);
268: VecNorm(state, NORM_2, &xnorm);
269: VecNorm(residual, NORM_2, &fnorm);
270: snes->ttol = fnorm*snes->rtol;
272: if (snes->ops->converged) {
273: (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
274: } else {
275: SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,NULL);
276: }
277: SNESLogConvergenceHistory(snes, fnorm, 0); /* should we count lits from the patches? */
278: SNESMonitor(snes, its, fnorm);
280: /* The main solver loop */
281: for (its = 0; its < snes->max_its; its++) {
283: SNESSetIterationNumber(snes, its);
285: /* Scatter state vector to overlapped vector on all patches.
286: The vector pcpatch->localState is scattered to each patch
287: in PCApply_PATCH_Nonlinear. */
288: VecGetArrayRead(state, &globalState);
289: VecGetArray(pcpatch->localState, &localState);
290: PetscSFBcastBegin(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState,MPI_REPLACE);
291: PetscSFBcastEnd(pcpatch->sectionSF, MPIU_SCALAR, globalState, localState,MPI_REPLACE);
292: VecRestoreArray(pcpatch->localState, &localState);
293: VecRestoreArrayRead(state, &globalState);
295: /* The looping over patches happens here */
296: PCApply(patch->pc, rhs, update);
298: /* Apply a line search. This will often be basic with
299: damping = 1/(max number of patches a dof can be in),
300: but not always */
301: VecScale(update, -1.0);
302: SNESLineSearchApply(ls, state, residual, &fnorm, update);
304: VecNorm(state, NORM_2, &xnorm);
305: VecNorm(update, NORM_2, &ynorm);
307: if (snes->ops->converged) {
308: (*snes->ops->converged)(snes,its,xnorm,ynorm,fnorm,&snes->reason,snes->cnvP);
309: } else {
310: SNESConvergedSkip(snes,its,xnorm,ynorm,fnorm,&snes->reason,NULL);
311: }
312: SNESLogConvergenceHistory(snes, fnorm, 0); /* FIXME: should we count lits? */
313: SNESMonitor(snes, its, fnorm);
314: }
316: if (its == snes->max_its) SNESSetConvergedReason(snes, SNES_DIVERGED_MAX_IT);
317: return 0;
318: }
320: /*MC
321: SNESPATCH - Solve a nonlinear problem by composing together many nonlinear solvers on patches
323: Level: intermediate
325: .seealso: SNESCreate(), SNESSetType(), SNESType (for list of available types), SNES,
326: PCPATCH
328: References:
329: . * - Peter R. Brune, Matthew G. Knepley, Barry F. Smith, and Xuemin Tu, "Composing Scalable Nonlinear Algebraic Solvers", SIAM Review, 57(4), 2015
331: M*/
332: PETSC_EXTERN PetscErrorCode SNESCreate_Patch(SNES snes)
333: {
334: SNES_Patch *patch;
335: PC_PATCH *patchpc;
336: SNESLineSearch linesearch;
338: PetscNewLog(snes, &patch);
340: snes->ops->solve = SNESSolve_Patch;
341: snes->ops->setup = SNESSetUp_Patch;
342: snes->ops->reset = SNESReset_Patch;
343: snes->ops->destroy = SNESDestroy_Patch;
344: snes->ops->setfromoptions = SNESSetFromOptions_Patch;
345: snes->ops->view = SNESView_Patch;
347: SNESGetLineSearch(snes,&linesearch);
348: if (!((PetscObject)linesearch)->type_name) {
349: SNESLineSearchSetType(linesearch,SNESLINESEARCHBASIC);
350: }
351: snes->usesksp = PETSC_FALSE;
353: snes->alwayscomputesfinalresidual = PETSC_FALSE;
355: snes->data = (void *) patch;
356: PCCreate(PetscObjectComm((PetscObject) snes), &patch->pc);
357: PCSetType(patch->pc, PCPATCH);
359: patchpc = (PC_PATCH*) patch->pc->data;
360: patchpc->classname = "snes";
361: patchpc->isNonlinear = PETSC_TRUE;
363: patchpc->setupsolver = PCSetUp_PATCH_Nonlinear;
364: patchpc->applysolver = PCApply_PATCH_Nonlinear;
365: patchpc->resetsolver = PCReset_PATCH_Nonlinear;
366: patchpc->destroysolver = PCDestroy_PATCH_Nonlinear;
367: patchpc->updatemultiplicative = PCUpdateMultiplicative_PATCH_Nonlinear;
369: return 0;
370: }
372: PetscErrorCode SNESPatchSetDiscretisationInfo(SNES snes, PetscInt nsubspaces, DM *dms, PetscInt *bs, PetscInt *nodesPerCell, const PetscInt **cellNodeMap,
373: const PetscInt *subspaceOffsets, PetscInt numGhostBcs, const PetscInt *ghostBcNodes, PetscInt numGlobalBcs, const PetscInt *globalBcNodes)
374: {
375: SNES_Patch *patch = (SNES_Patch *) snes->data;
376: DM dm;
378: SNESGetDM(snes, &dm);
380: PCSetDM(patch->pc, dm);
381: PCPatchSetDiscretisationInfo(patch->pc, nsubspaces, dms, bs, nodesPerCell, cellNodeMap, subspaceOffsets, numGhostBcs, ghostBcNodes, numGlobalBcs, globalBcNodes);
382: return 0;
383: }
385: PetscErrorCode SNESPatchSetComputeOperator(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Mat, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
386: {
387: SNES_Patch *patch = (SNES_Patch *) snes->data;
389: PCPatchSetComputeOperator(patch->pc, func, ctx);
390: return 0;
391: }
393: PetscErrorCode SNESPatchSetComputeFunction(SNES snes, PetscErrorCode (*func)(PC, PetscInt, Vec, Vec, IS, PetscInt, const PetscInt *, const PetscInt *, void *), void *ctx)
394: {
395: SNES_Patch *patch = (SNES_Patch *) snes->data;
397: PCPatchSetComputeFunction(patch->pc, func, ctx);
398: return 0;
399: }
401: PetscErrorCode SNESPatchSetConstructType(SNES snes, PCPatchConstructType ctype, PetscErrorCode (*func)(PC, PetscInt *, IS **, IS *, void *), void *ctx)
402: {
403: SNES_Patch *patch = (SNES_Patch *) snes->data;
405: PCPatchSetConstructType(patch->pc, ctype, func, ctx);
406: return 0;
407: }
409: PetscErrorCode SNESPatchSetCellNumbering(SNES snes, PetscSection cellNumbering)
410: {
411: SNES_Patch *patch = (SNES_Patch *) snes->data;
413: PCPatchSetCellNumbering(patch->pc, cellNumbering);
414: return 0;
415: }