Actual source code: lmproducts.c
1: #include <petsc/private/petscimpl.h>
2: #include <petscmat.h>
3: #include <petscblaslapack.h>
4: #include <petscdevice.h>
5: #include "lmproducts.h"
6: #include "blas_cyclic/blas_cyclic.h"
// Log events timing the LMProducts kernels: LMPROD_Mult wraps LMProductsMult()/LMProductsMultHermitian(),
// LMPROD_Solve wraps LMProductsSolve(), and LMPROD_Update wraps the recomputation in LMProductsUpdate().
PetscLogEvent LMPROD_Mult, LMPROD_Solve, LMPROD_Update;
10: PETSC_INTERN PetscErrorCode LMProductsCreate(LMBasis basis, LMBlockType block_type, LMProducts *dots)
11: {
12: PetscInt m, m_local;
14: PetscFunctionBegin;
15: PetscAssertPointer(basis, 1);
17: PetscCheck(block_type >= 0 && block_type < LMBLOCK_END, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Invalid LMBlockType");
18: PetscCall(PetscNew(dots));
19: (*dots)->m = m = basis->m;
20: (*dots)->block_type = block_type;
21: PetscCall(MatGetLocalSize(basis->vecs, NULL, &m_local));
22: (*dots)->m_local = m_local;
23: if (block_type == LMBLOCK_DIAGONAL) {
24: VecType vec_type;
26: PetscCall(MatCreateVecs(basis->vecs, &(*dots)->diagonal_global, NULL));
27: PetscCall(VecCreateLocalVector((*dots)->diagonal_global, &(*dots)->diagonal_local));
28: PetscCall(VecGetType((*dots)->diagonal_local, &vec_type));
29: PetscCall(VecCreate(PETSC_COMM_SELF, &(*dots)->diagonal_dup));
30: PetscCall(VecSetSizes((*dots)->diagonal_dup, m, m));
31: PetscCall(VecSetType((*dots)->diagonal_dup, vec_type));
32: PetscCall(VecSetUp((*dots)->diagonal_dup));
33: } else {
34: VecType vec_type;
36: PetscCall(MatGetVecType(basis->vecs, &vec_type));
37: PetscCall(MatCreateDenseFromVecType(PetscObjectComm((PetscObject)basis->vecs), vec_type, m_local, m_local, m, m, m_local, NULL, &(*dots)->full));
38: }
39: PetscFunctionReturn(PETSC_SUCCESS);
40: }
42: PETSC_INTERN PetscErrorCode LMProductsDestroy(LMProducts *dots_p)
43: {
44: PetscFunctionBegin;
45: LMProducts dots = *dots_p;
46: if (dots == NULL) PetscFunctionReturn(PETSC_SUCCESS);
47: PetscCall(MatDestroy(&dots->full));
48: PetscCall(VecDestroy(&dots->diagonal_dup));
49: PetscCall(VecDestroy(&dots->diagonal_local));
50: PetscCall(VecDestroy(&dots->diagonal_global));
51: PetscCall(VecDestroy(&dots->rhs_local));
52: PetscCall(VecDestroy(&dots->lhs_local));
53: PetscCall(PetscFree(dots));
54: PetscFunctionReturn(PETSC_SUCCESS);
55: }
57: static PetscErrorCode LMProductsPrepare_Internal(LMProducts dots, PetscObjectId operator_id, PetscObjectState operator_state, PetscInt oldest, PetscInt next)
58: {
59: PetscFunctionBegin;
60: if (dots->operator_id != operator_id || dots->operator_state != operator_state) {
61: // invalidate the block
62: dots->operator_id = operator_id;
63: dots->operator_state = operator_state;
64: dots->k = oldest;
65: }
66: dots->k = PetscMax(oldest, dots->k);
67: PetscFunctionReturn(PETSC_SUCCESS);
68: }
70: static PetscErrorCode LMProductsPrepareFromBases(LMProducts dots, LMBasis X, LMBasis Y)
71: {
72: PetscInt oldest, next;
73: PetscObjectId operator_id = (X->operator_id == 0) ? Y->operator_id : X->operator_id;
74: PetscObjectId operator_state = (X->operator_id == 0) ? Y->operator_state : X->operator_state;
76: PetscFunctionBegin;
77: PetscCall(LMBasisGetRange(X, &oldest, &next));
78: PetscCall(LMProductsPrepare_Internal(dots, operator_id, operator_state, oldest, next));
79: PetscFunctionReturn(PETSC_SUCCESS);
80: }
82: PETSC_INTERN PetscErrorCode LMProductsPrepare(LMProducts dots, Mat op, PetscInt oldest, PetscInt next)
83: {
84: PetscObjectId operator_id;
85: PetscObjectState operator_state;
87: PetscFunctionBegin;
88: PetscCall(PetscObjectGetId((PetscObject)op, &operator_id));
89: PetscCall(PetscObjectStateGet((PetscObject)op, &operator_state));
90: PetscCall(LMProductsPrepare_Internal(dots, operator_id, operator_state, oldest, next));
91: PetscFunctionReturn(PETSC_SUCCESS);
92: }
/*
  LMProductsUpdate_Internal - Bring the stored block of X^H Y up to date for basis indices [oldest, next).

  Input Parameters:
+ dots   - product storage; dots->block_type decides which entries are computed
. X      - left basis
. Y      - right basis (may be the same object as X)
. oldest - first index to include (LMProductsUpdate() obtains it from LMBasisGetRange(X, ...))
- next   - one past the last index to include

  Notes:
  Only indices in [dots->k, next) are computed, after LMProductsPrepareFromBases() has possibly
  reset dots->k.  Basis indices map to storage slots cyclically via (index % dots->m).
  Entry (i, j) of the dense storage holds x_i^H y_j.  The debug check below reads it column-major,
  as values[(j % m) * lda + (i % m)].
  NOTE(review): the GEMMH/GEMVH comments below assume LMBasisGEMMH(X, a, b, Y, c, d, ...) writes
  X[a, b)^H Y[c, d) into the matching cyclic block of dots->full -- confirm in blas_cyclic.
*/
static PetscErrorCode LMProductsUpdate_Internal(LMProducts dots, LMBasis X, LMBasis Y, PetscInt oldest, PetscInt next)
{
  MPI_Comm comm = PetscObjectComm((PetscObject)X->vecs);
  PetscInt start;

  PetscFunctionBegin;
  // X, Y, and dots must share the same capacity m, and X and Y must agree in k
  PetscAssert(X->m == Y->m && X->m == dots->m, comm, PETSC_ERR_ARG_INCOMP, "X vecs, Y vecs, and dots incompatible in size, (%d, %d, %d)", (int)X->m, (int)Y->m, (int)dots->m);
  PetscAssert(X->k == Y->k, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are incompatible in state, (%d, %d)", (int)X->k, (int)Y->k);
  PetscAssert(dots->k <= X->k, comm, PETSC_ERR_ARG_INCOMP, "Dot products are ahead of X and Y, (%d, %d)", (int)dots->k, (int)X->k);
  // an operator_id of 0 is accepted alongside any other id; equal ids must also have equal states
  PetscAssert(X->operator_id == 0 || Y->operator_id == 0 || X->operator_id == Y->operator_id, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are from different operators");
  PetscAssert(X->operator_id != Y->operator_id || Y->operator_state == X->operator_state, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are from different operator states");

  // resets dots->k to oldest if the operator id/state changed, and never leaves it below oldest
  PetscCall(LMProductsPrepareFromBases(dots, X, Y));

  start = dots->k;
  // nothing new to compute
  if (start == next) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscLogEventBegin(LMPROD_Update, NULL, NULL, NULL, NULL));
  switch (dots->block_type) {
  case LMBLOCK_DIAGONAL:
    // one inner product per new index; LMProductsInsertNextDiagonalValue() requires i == dots->k and advances it
    for (PetscInt i = start; i < next; i++) {
      Vec         x, y;
      PetscScalar xTy;

      PetscCall(LMBasisGetVecRead(X, i, &x));
      y = x;
      // do not fetch the same vector twice when X and Y are the same basis
      if (Y != X) PetscCall(LMBasisGetVecRead(Y, i, &y));
      PetscCall(VecDot(y, x, &xTy));
      if (Y != X) PetscCall(LMBasisRestoreVecRead(Y, i, &y));
      PetscCall(LMBasisRestoreVecRead(X, i, &x));
      PetscCall(LMProductsInsertNextDiagonalValue(dots, i, xTy));
    }
    break;
  case LMBLOCK_STRICT_UPPER_TRIANGLE: {
    Mat local;

    PetscCall(MatDenseGetLocalMatrix(dots->full, &local));
    // we have to proceed index by index because we want to zero each row after we compute the corresponding column
    for (PetscInt i = start; i < next; i++) {
      Mat row;
      Vec column, y;

      // storage column i % m <- X[oldest, next)^H y_i
      PetscCall(LMBasisGetVecRead(Y, i, &y));
      PetscCall(MatDenseGetColumnVec(dots->full, i % dots->m, &column));
      PetscCall(LMBasisGEMVH(X, oldest, next, 1.0, y, 0.0, column));
      PetscCall(MatDenseRestoreColumnVec(dots->full, i % dots->m, &column));
      PetscCall(LMBasisRestoreVecRead(Y, i, &y));

      // zero out the new row, including its diagonal entry; columns computed later in this loop
      // write the entries of this row that lie strictly above the diagonal.
      // Only ranks that hold local rows have anything to zero.
      if (dots->m_local) {
        PetscCall(MatDenseGetSubMatrix(local, i % dots->m, (i % dots->m) + 1, PETSC_DECIDE, PETSC_DECIDE, &row));
        PetscCall(MatZeroEntries(row));
        PetscCall(MatDenseRestoreSubMatrix(local, &row));
      }
    }
  } break;
  case LMBLOCK_UPPER_TRIANGLE: {
    // mid: the largest multiple of m that does not exceed next
    PetscInt mid       = next - (next % dots->m);
    PetscInt start_idx = start % dots->m;
    PetscInt next_idx  = ((next - 1) % dots->m) + 1;

    if (next_idx > start_idx) {
      // [start, next) occupies a contiguous run of storage slots
      PetscCall(LMBasisGEMMH(X, oldest, next, Y, start, next, 1.0, 0.0, dots->full));
    } else {
      // [start, next) wraps past the end of the cyclic storage: split it at mid
      PetscCall(LMBasisGEMMH(X, oldest, mid, Y, start, mid, 1.0, 0.0, dots->full));
      PetscCall(LMBasisGEMMH(X, oldest, next, Y, mid, next, 1.0, 0.0, dots->full));
    }
  } break;
  case LMBLOCK_FULL:
    // new columns: X[oldest, next)^H Y[start, next)
    PetscCall(LMBasisGEMMH(X, oldest, next, Y, start, next, 1.0, 0.0, dots->full));
    // new rows against the previously stored columns: X[start, next)^H Y[oldest, start)
    PetscCall(LMBasisGEMMH(X, start, next, Y, oldest, start, 1.0, 0.0, dots->full));
    break;
  default:
    PetscUnreachable();
  }
  dots->k = next;
  if (dots->debug) {
    // Debug mode: compare every stored entry in the window against a direct VecDot()
    const PetscScalar *values = NULL;
    PetscInt           lda;
    PetscInt           N;

    PetscCall(MatGetSize(X->vecs, &N, NULL));
    if (dots->block_type == LMBLOCK_DIAGONAL) {
      // with lda = 0 the indexing below reduces to values[i % m]
      lda = 0;
      if (dots->update_diagonal_global) {
        PetscCall(VecGetArrayRead(dots->diagonal_global, &values));
      } else {
        PetscCall(VecGetArrayRead(dots->diagonal_dup, &values));
      }
    } else {
      PetscCall(MatDenseGetLDA(dots->full, &lda));
      PetscCall(MatDenseGetArrayRead(dots->full, &values));
    }
    for (PetscInt i = oldest; i < next; i++) {
      Vec       x_i_, x_i;
      PetscReal x_norm;
      PetscInt  j_start = oldest;
      PetscInt  j_end   = next;

      // work on a private copy of x_i so that Y's vectors can be fetched while it is in use
      PetscCall(LMBasisGetVecRead(X, i, &x_i_));
      PetscCall(VecNorm(x_i_, NORM_1, &x_norm));
      PetscCall(VecDuplicate(x_i_, &x_i));
      PetscCall(VecCopy(x_i_, x_i));
      PetscCall(LMBasisRestoreVecRead(X, i, &x_i_));

      // restrict j to the entries this block type actually stores
      switch (dots->block_type) {
      case LMBLOCK_DIAGONAL:
        j_start = i;
        j_end   = i + 1;
        break;
      case LMBLOCK_UPPER_TRIANGLE:
        j_start = i;
        break;
      case LMBLOCK_STRICT_UPPER_TRIANGLE:
        j_start = i + 1;
        break;
      default:
        break;
      }
      for (PetscInt j = j_start; j < j_end; j++) {
        Vec         y_j;
        PetscScalar dot_true, dot = 0.0, diff;
        PetscReal   y_norm;

        PetscCall(LMBasisGetVecRead(Y, j, &y_j));
        PetscCall(VecDot(y_j, x_i, &dot_true));
        PetscCall(VecNorm(y_j, NORM_1, &y_norm));
        // only ranks with local storage can read the stored value; it is then broadcast from rank 0
        // NOTE(review): assumes rank 0 is a rank with m_local > 0 -- confirm
        if (dots->m_local) dot = values[(j % dots->m) * lda + (i % dots->m)];
        PetscCallMPI(MPI_Bcast(&dot, 1, MPIU_SCALAR, 0, comm));
        diff = dot_true - dot;
        // tolerance scales with the global length N and the 1-norms of both vectors
        if (PetscDefined(USE_COMPLEX)) {
          PetscCheck(PetscAbsScalar(diff) <= PETSC_SMALL * N * x_norm * y_norm, comm, PETSC_ERR_PLIB, "LMProducts debug: dots[%" PetscInt_FMT ", %" PetscInt_FMT "] = %g + i*%g != VecDot() = %g + i*%g", i, j, (double)PetscRealPart(dot), (double)PetscImaginaryPart(dot), (double)PetscRealPart(dot_true), (double)PetscImaginaryPart(dot_true));
        } else {
          PetscCheck(PetscAbsScalar(diff) <= PETSC_SMALL * N * x_norm * y_norm, comm, PETSC_ERR_PLIB, "LMProducts debug: dots[%" PetscInt_FMT ", %" PetscInt_FMT "] = %g != VecDot() = %g", i, j, (double)PetscRealPart(dot), (double)PetscRealPart(dot_true));
        }
        PetscCall(LMBasisRestoreVecRead(Y, j, &y_j));
      }

      PetscCall(VecDestroy(&x_i));
    }

    if (dots->block_type == LMBLOCK_DIAGONAL) {
      if (dots->update_diagonal_global) {
        PetscCall(VecRestoreArrayRead(dots->diagonal_global, &values));
      } else {
        PetscCall(VecRestoreArrayRead(dots->diagonal_dup, &values));
      }
    } else {
      PetscCall(MatDenseRestoreArrayRead(dots->full, &values));
    }
  }
  PetscCall(PetscLogEventEnd(LMPROD_Update, NULL, NULL, NULL, NULL));
  PetscFunctionReturn(PETSC_SUCCESS);
}
248: // dots = X^H Y
249: PETSC_INTERN PetscErrorCode LMProductsUpdate(LMProducts dots, LMBasis X, LMBasis Y)
250: {
251: PetscInt oldest, next;
253: PetscFunctionBegin;
254: PetscCall(LMBasisGetRange(X, &oldest, &next));
255: PetscCall(LMProductsUpdate_Internal(dots, X, Y, oldest, next));
256: PetscFunctionReturn(PETSC_SUCCESS);
257: }
259: PETSC_INTERN PetscErrorCode LMProductsCopy(LMProducts src, LMProducts dest)
260: {
261: PetscFunctionBegin;
262: PetscCheck(dest->m == src->m, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different size");
263: PetscCheck(dest->m_local == src->m_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different size");
264: PetscCheck(dest->block_type == src->block_type, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different block type");
265: dest->k = src->k;
266: dest->m_local = src->m_local;
267: if (src->full) PetscCall(MatCopy(src->full, dest->full, DIFFERENT_NONZERO_PATTERN));
268: if (src->diagonal_dup) PetscCall(VecCopy(src->diagonal_dup, dest->diagonal_dup));
269: if (src->diagonal_global) PetscCall(VecCopy(src->diagonal_global, dest->diagonal_global));
270: dest->update_diagonal_global = src->update_diagonal_global;
271: dest->operator_id = src->operator_id;
272: dest->operator_state = src->operator_state;
273: PetscFunctionReturn(PETSC_SUCCESS);
274: }
276: PETSC_INTERN PetscErrorCode LMProductsScale(LMProducts dots, PetscScalar scale)
277: {
278: PetscFunctionBegin;
279: if (dots->full) PetscCall(MatScale(dots->full, scale));
280: if (dots->diagonal_dup) PetscCall(VecScale(dots->diagonal_dup, scale));
281: if (dots->diagonal_global) PetscCall(VecScale(dots->diagonal_global, scale));
282: PetscFunctionReturn(PETSC_SUCCESS);
283: }
285: PETSC_INTERN PetscErrorCode LMProductsGetLocalMatrix(LMProducts dots, Mat *G_local, PetscInt *k, PetscBool *local_is_nonempty)
286: {
287: PetscFunctionBegin;
288: PetscCheck(dots->block_type != LMBLOCK_DIAGONAL, PETSC_COMM_SELF, PETSC_ERR_SUP, "Asking for full matrix of diagonal products");
289: PetscCall(MatDenseGetLocalMatrix(dots->full, G_local));
290: if (k) *k = dots->k;
291: if (local_is_nonempty) *local_is_nonempty = (dots->m_local == dots->m) ? PETSC_TRUE : PETSC_FALSE;
292: PetscFunctionReturn(PETSC_SUCCESS);
293: }
295: PETSC_INTERN PetscErrorCode LMProductsRestoreLocalMatrix(LMProducts dots, Mat *G_local, PetscInt *k)
296: {
297: PetscFunctionBegin;
298: if (G_local) *G_local = NULL;
299: if (k) dots->k = *k;
300: PetscFunctionReturn(PETSC_SUCCESS);
301: }
303: static PetscErrorCode LMProductsGetUpdatedDiagonal(LMProducts dots, Vec *diagonal)
304: {
305: PetscFunctionBegin;
306: if (!dots->update_diagonal_global) {
307: PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
308: if (dots->m_local) PetscCall(VecCopy(dots->diagonal_dup, dots->diagonal_local));
309: PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
310: dots->update_diagonal_global = PETSC_TRUE;
311: }
312: if (diagonal) *diagonal = dots->diagonal_global;
313: PetscFunctionReturn(PETSC_SUCCESS);
314: }
316: PETSC_INTERN PetscErrorCode LMProductsGetLocalDiagonal(LMProducts dots, Vec *D_local)
317: {
318: PetscFunctionBegin;
319: PetscCall(LMProductsGetUpdatedDiagonal(dots, NULL));
320: PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
321: *D_local = dots->diagonal_local;
322: PetscFunctionReturn(PETSC_SUCCESS);
323: }
325: PETSC_INTERN PetscErrorCode LMProductsRestoreLocalDiagonal(LMProducts dots, Vec *D_local)
326: {
327: PetscFunctionBegin;
328: PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
329: *D_local = NULL;
330: PetscFunctionReturn(PETSC_SUCCESS);
331: }
333: PETSC_INTERN PetscErrorCode LMProductsGetNextColumn(LMProducts dots, Vec *col)
334: {
335: PetscFunctionBegin;
336: PetscCheck(dots->block_type != LMBLOCK_DIAGONAL, PETSC_COMM_SELF, PETSC_ERR_SUP, "Asking for column of diagonal products");
337: PetscCall(MatDenseGetColumnVecWrite(dots->full, dots->k % dots->m, col));
338: PetscFunctionReturn(PETSC_SUCCESS);
339: }
341: PETSC_INTERN PetscErrorCode LMProductsRestoreNextColumn(LMProducts dots, Vec *col)
342: {
343: PetscFunctionBegin;
344: PetscCall(MatDenseRestoreColumnVecWrite(dots->full, dots->k % dots->m, col));
345: dots->k++;
346: PetscFunctionReturn(PETSC_SUCCESS);
347: }
349: // copy conj(triu(G)) into tril(G)
350: PETSC_INTERN PetscErrorCode LMProductsMakeHermitian(Mat local, PetscInt oldest, PetscInt next)
351: {
352: PetscInt m;
354: PetscFunctionBegin;
355: PetscCall(MatGetLocalSize(local, &m, NULL));
356: if (m) {
357: // TODO: implement on device?
358: PetscScalar *a;
359: PetscInt lda;
361: PetscCall(MatDenseGetLDA(local, &lda));
362: PetscCall(MatDenseGetArray(local, &a));
363: for (PetscInt j_ = oldest; j_ < next; j_++) {
364: PetscInt j = j_ % m;
366: a[j + j * lda] = PetscRealPart(a[j + j * lda]);
367: for (PetscInt i_ = j_ + 1; i_ < next; i_++) {
368: PetscInt i = i_ % m;
370: a[i + j * lda] = PetscConj(a[j + i * lda]);
371: }
372: }
373: }
374: PetscFunctionReturn(PETSC_SUCCESS);
375: }
377: PETSC_INTERN PetscErrorCode LMProductsSolve(LMProducts dots, PetscInt oldest, PetscInt next, Vec b, Vec x, PetscBool hermitian_transpose)
378: {
379: PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
380: PetscInt dots_next = dots->k;
381: Mat local;
382: Vec diag = NULL;
384: PetscFunctionBegin;
385: PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
386: if (oldest >= next) PetscFunctionReturn(PETSC_SUCCESS);
387: PetscCall(PetscLogEventBegin(LMPROD_Solve, NULL, NULL, NULL, NULL));
388: if (!dots->rhs_local) PetscCall(VecCreateLocalVector(b, &dots->rhs_local));
389: if (!dots->lhs_local) PetscCall(VecDuplicate(dots->rhs_local, &dots->lhs_local));
390: switch (dots->block_type) {
391: case LMBLOCK_DIAGONAL:
392: PetscCall(LMProductsGetUpdatedDiagonal(dots, &diag));
393: PetscCall(VecDSVCyclic(hermitian_transpose, oldest, next, diag, b, x));
394: break;
395: case LMBLOCK_UPPER_TRIANGLE:
396: PetscCall(MatSeqDenseTRSVCyclic(hermitian_transpose, oldest, next, dots->full, b, x));
397: break;
398: default: {
399: PetscCall(MatDenseGetLocalMatrix(dots->full, &local));
400: PetscCall(VecGetLocalVector(b, dots->rhs_local));
401: PetscCall(VecGetLocalVector(x, dots->lhs_local));
402: if (dots->m_local) {
403: if (!hermitian_transpose) {
404: PetscCall(MatSolve(local, dots->rhs_local, dots->lhs_local));
405: } else {
406: Vec rhs_conj = dots->rhs_local;
408: if (PetscDefined(USE_COMPLEX)) {
409: PetscCall(VecDuplicate(dots->rhs_local, &rhs_conj));
410: PetscCall(VecCopy(dots->rhs_local, rhs_conj));
411: PetscCall(VecConjugate(rhs_conj));
412: }
413: PetscCall(MatSolveTranspose(local, rhs_conj, dots->lhs_local));
414: if (PetscDefined(USE_COMPLEX)) {
415: PetscCall(VecConjugate(dots->lhs_local));
416: PetscCall(VecDestroy(&rhs_conj));
417: }
418: }
419: }
420: if (x != b) PetscCall(VecRestoreLocalVector(x, dots->lhs_local));
421: PetscCall(VecRestoreLocalVector(b, dots->rhs_local));
422: } break;
423: }
424: PetscCall(PetscLogEventEnd(LMPROD_Solve, NULL, NULL, NULL, NULL));
425: PetscFunctionReturn(PETSC_SUCCESS);
426: }
428: PETSC_INTERN PetscErrorCode LMProductsMult(LMProducts dots, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y, PetscBool hermitian_transpose)
429: {
430: PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
431: PetscInt dots_next = dots->k;
432: Vec diag = NULL;
434: PetscFunctionBegin;
435: PetscCall(PetscLogEventBegin(LMPROD_Mult, NULL, NULL, NULL, NULL));
436: PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
437: switch (dots->block_type) {
438: case LMBLOCK_DIAGONAL: {
439: PetscCall(LMProductsGetUpdatedDiagonal(dots, &diag));
440: PetscCall(VecDMVCyclic(hermitian_transpose, oldest, next, alpha, diag, x, beta, y));
441: } break;
442: case LMBLOCK_STRICT_UPPER_TRIANGLE: // the lower triangle has been zeroed, MatMult() is safe
443: case LMBLOCK_FULL:
444: PetscCall(MatSeqDenseGEMVCyclic(hermitian_transpose, oldest, next, alpha, dots->full, x, beta, y));
445: break;
446: default:
447: SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
448: }
449: PetscCall(PetscLogEventEnd(LMPROD_Mult, NULL, NULL, NULL, NULL));
450: PetscFunctionReturn(PETSC_SUCCESS);
451: }
453: PETSC_INTERN PetscErrorCode LMProductsMultHermitian(LMProducts dots, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
454: {
455: PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
456: PetscInt dots_next = dots->k;
458: PetscFunctionBegin;
459: PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
460: if (dots->block_type == LMBLOCK_DIAGONAL) PetscCall(LMProductsMult(dots, oldest, next, alpha, x, beta, y, PETSC_FALSE));
461: else {
462: PetscCall(PetscLogEventBegin(LMPROD_Mult, NULL, NULL, NULL, NULL));
463: PetscCall(MatSeqDenseHEMVCyclic(oldest, next, alpha, dots->full, x, beta, y));
464: PetscCall(PetscLogEventEnd(LMPROD_Mult, NULL, NULL, NULL, NULL));
465: }
466: PetscFunctionReturn(PETSC_SUCCESS);
467: }
469: PETSC_INTERN PetscErrorCode LMProductsReset(LMProducts dots)
470: {
471: PetscFunctionBegin;
472: if (dots) {
473: dots->k = 0;
474: dots->operator_id = 0;
475: dots->operator_state = 0;
476: if (dots->full) {
477: Mat full_local;
479: PetscCall(MatDenseGetLocalMatrix(dots->full, &full_local));
480: PetscCall(MatSetUnfactored(full_local));
481: PetscCall(MatZeroEntries(full_local));
482: }
483: if (dots->diagonal_global) PetscCall(VecZeroEntries(dots->diagonal_dup));
484: if (dots->diagonal_dup) PetscCall(VecZeroEntries(dots->diagonal_dup));
485: }
486: PetscFunctionReturn(PETSC_SUCCESS);
487: }
489: PETSC_INTERN PetscErrorCode LMProductsGetDiagonalValue(LMProducts dots, PetscInt i, PetscScalar *v)
490: {
491: PetscFunctionBegin;
492: PetscInt oldest = PetscMax(0, dots->k - dots->m);
493: PetscInt next = dots->k;
494: PetscInt idx = i % dots->m;
495: PetscCheck(i >= oldest && i < next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Inserting value %d out of range [%d, %d)", (int)i, (int)oldest, (int)next);
496: PetscCall(VecGetValues(dots->diagonal_dup, 1, &idx, v));
497: PetscFunctionReturn(PETSC_SUCCESS);
498: }
500: PETSC_INTERN PetscErrorCode LMProductsInsertNextDiagonalValue(LMProducts dots, PetscInt i, PetscScalar v)
501: {
502: PetscInt idx = i % dots->m;
504: PetscFunctionBegin;
505: PetscCheck(i == dots->k, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is not the next index (%" PetscInt_FMT ")", i, dots->k);
506: PetscCall(VecSetValue(dots->diagonal_dup, idx, v, INSERT_VALUES));
507: if (dots->update_diagonal_global) {
508: PetscScalar *array;
509: PetscMemType memtype;
511: PetscCall(VecGetArrayAndMemType(dots->diagonal_global, &array, &memtype));
512: if (dots->m_local > 0) {
513: if (PetscMemTypeHost(memtype)) {
514: array[idx] = v;
515: PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
516: } else {
517: PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
518: PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
519: if (dots->m_local) PetscCall(VecCopy(dots->diagonal_dup, dots->diagonal_local));
520: PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
521: }
522: } else {
523: PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
524: }
525: }
526: dots->k++;
527: PetscFunctionReturn(PETSC_SUCCESS);
528: }
530: PETSC_INTERN PetscErrorCode LMProductsOnesOnUnusedDiagonal(Mat A, PetscInt oldest, PetscInt next)
531: {
532: PetscInt m;
533: Mat sub;
535: PetscFunctionBegin;
536: PetscCall(MatGetSize(A, &m, NULL));
537: // we could handle the general case but this is the only case used by MatLMVM
538: PetscCheck((next < m && oldest == 0) || next - oldest == m, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "General case not implemented");
539: if (next - oldest == m) PetscFunctionReturn(PETSC_SUCCESS); // nothing to do if all entries are used
540: PetscCall(MatDenseGetSubMatrix(A, next, m, next, m, &sub));
541: PetscCall(MatShift(sub, 1.0));
542: PetscCall(MatDenseRestoreSubMatrix(A, &sub));
543: PetscFunctionReturn(PETSC_SUCCESS);
544: }