Actual source code: lmproducts.c

  1: #include <petsc/private/petscimpl.h>
  2: #include <petscmat.h>
  3: #include <petscblaslapack.h>
  4: #include <petscdevice.h>
  5: #include "lmproducts.h"
  6: #include "blas_cyclic/blas_cyclic.h"

  8: PetscLogEvent LMPROD_Mult, LMPROD_Solve, LMPROD_Update;

 10: PETSC_INTERN PetscErrorCode LMProductsCreate(LMBasis basis, LMBlockType block_type, LMProducts *dots)
 11: {
 12:   PetscInt m, m_local;

 14:   PetscFunctionBegin;
 15:   PetscAssertPointer(basis, 1);
 17:   PetscCheck(block_type >= 0 && block_type < LMBLOCK_END, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Invalid LMBlockType");
 18:   PetscCall(PetscNew(dots));
 19:   (*dots)->m = m      = basis->m;
 20:   (*dots)->block_type = block_type;
 21:   PetscCall(MatGetLocalSize(basis->vecs, NULL, &m_local));
 22:   (*dots)->m_local = m_local;
 23:   if (block_type == LMBLOCK_DIAGONAL) {
 24:     VecType vec_type;

 26:     PetscCall(MatCreateVecs(basis->vecs, &(*dots)->diagonal_global, NULL));
 27:     PetscCall(VecCreateLocalVector((*dots)->diagonal_global, &(*dots)->diagonal_local));
 28:     PetscCall(VecGetType((*dots)->diagonal_local, &vec_type));
 29:     PetscCall(VecCreate(PETSC_COMM_SELF, &(*dots)->diagonal_dup));
 30:     PetscCall(VecSetSizes((*dots)->diagonal_dup, m, m));
 31:     PetscCall(VecSetType((*dots)->diagonal_dup, vec_type));
 32:     PetscCall(VecSetUp((*dots)->diagonal_dup));
 33:   } else {
 34:     VecType vec_type;

 36:     PetscCall(MatGetVecType(basis->vecs, &vec_type));
 37:     PetscCall(MatCreateDenseFromVecType(PetscObjectComm((PetscObject)basis->vecs), vec_type, m_local, m_local, m, m, m_local, NULL, &(*dots)->full));
 38:   }
 39:   PetscFunctionReturn(PETSC_SUCCESS);
 40: }

 42: PETSC_INTERN PetscErrorCode LMProductsDestroy(LMProducts *dots_p)
 43: {
 44:   PetscFunctionBegin;
 45:   LMProducts dots = *dots_p;
 46:   if (dots == NULL) PetscFunctionReturn(PETSC_SUCCESS);
 47:   PetscCall(MatDestroy(&dots->full));
 48:   PetscCall(VecDestroy(&dots->diagonal_dup));
 49:   PetscCall(VecDestroy(&dots->diagonal_local));
 50:   PetscCall(VecDestroy(&dots->diagonal_global));
 51:   PetscCall(VecDestroy(&dots->rhs_local));
 52:   PetscCall(VecDestroy(&dots->lhs_local));
 53:   PetscCall(PetscFree(dots));
 54:   PetscFunctionReturn(PETSC_SUCCESS);
 55: }

 57: static PetscErrorCode LMProductsPrepare_Internal(LMProducts dots, PetscObjectId operator_id, PetscObjectState operator_state, PetscInt oldest, PetscInt next)
 58: {
 59:   PetscFunctionBegin;
 60:   if (dots->operator_id != operator_id || dots->operator_state != operator_state) {
 61:     // invalidate the block
 62:     dots->operator_id    = operator_id;
 63:     dots->operator_state = operator_state;
 64:     dots->k              = oldest;
 65:   }
 66:   dots->k = PetscMax(oldest, dots->k);
 67:   PetscFunctionReturn(PETSC_SUCCESS);
 68: }

 70: static PetscErrorCode LMProductsPrepareFromBases(LMProducts dots, LMBasis X, LMBasis Y)
 71: {
 72:   PetscInt      oldest, next;
 73:   PetscObjectId operator_id    = (X->operator_id == 0) ? Y->operator_id : X->operator_id;
 74:   PetscObjectId operator_state = (X->operator_id == 0) ? Y->operator_state : X->operator_state;

 76:   PetscFunctionBegin;
 77:   PetscCall(LMBasisGetRange(X, &oldest, &next));
 78:   PetscCall(LMProductsPrepare_Internal(dots, operator_id, operator_state, oldest, next));
 79:   PetscFunctionReturn(PETSC_SUCCESS);
 80: }

 82: PETSC_INTERN PetscErrorCode LMProductsPrepare(LMProducts dots, Mat op, PetscInt oldest, PetscInt next)
 83: {
 84:   PetscObjectId    operator_id;
 85:   PetscObjectState operator_state;

 87:   PetscFunctionBegin;
 88:   PetscCall(PetscObjectGetId((PetscObject)op, &operator_id));
 89:   PetscCall(PetscObjectStateGet((PetscObject)op, &operator_state));
 90:   PetscCall(LMProductsPrepare_Internal(dots, operator_id, operator_state, oldest, next));
 91:   PetscFunctionReturn(PETSC_SUCCESS);
 92: }

 94: static PetscErrorCode LMProductsUpdate_Internal(LMProducts dots, LMBasis X, LMBasis Y, PetscInt oldest, PetscInt next)
 95: {
 96:   MPI_Comm comm = PetscObjectComm((PetscObject)X->vecs);
 97:   PetscInt start;

 99:   PetscFunctionBegin;
100:   PetscAssert(X->m == Y->m && X->m == dots->m, comm, PETSC_ERR_ARG_INCOMP, "X vecs, Y vecs, and dots incompatible in size, (%d, %d, %d)", (int)X->m, (int)Y->m, (int)dots->m);
101:   PetscAssert(X->k == Y->k, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are incompatible in state, (%d, %d)", (int)X->k, (int)Y->k);
102:   PetscAssert(dots->k <= X->k, comm, PETSC_ERR_ARG_INCOMP, "Dot products are ahead of X and Y, (%d, %d)", (int)dots->k, (int)X->k);
103:   PetscAssert(X->operator_id == 0 || Y->operator_id == 0 || X->operator_id == Y->operator_id, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are from different operators");
104:   PetscAssert(X->operator_id != Y->operator_id || Y->operator_state == X->operator_state, comm, PETSC_ERR_ARG_INCOMP, "X and Y vecs are from different operator states");

106:   PetscCall(LMProductsPrepareFromBases(dots, X, Y));

108:   start = dots->k;
109:   if (start == next) PetscFunctionReturn(PETSC_SUCCESS);
110:   PetscCall(PetscLogEventBegin(LMPROD_Update, NULL, NULL, NULL, NULL));
111:   switch (dots->block_type) {
112:   case LMBLOCK_DIAGONAL:
113:     for (PetscInt i = start; i < next; i++) {
114:       Vec         x, y;
115:       PetscScalar xTy;

117:       PetscCall(LMBasisGetVecRead(X, i, &x));
118:       y = x;
119:       if (Y != X) PetscCall(LMBasisGetVecRead(Y, i, &y));
120:       PetscCall(VecDot(y, x, &xTy));
121:       if (Y != X) PetscCall(LMBasisRestoreVecRead(Y, i, &y));
122:       PetscCall(LMBasisRestoreVecRead(X, i, &x));
123:       PetscCall(LMProductsInsertNextDiagonalValue(dots, i, xTy));
124:     }
125:     break;
126:   case LMBLOCK_STRICT_UPPER_TRIANGLE: {
127:     Mat local;

129:     PetscCall(MatDenseGetLocalMatrix(dots->full, &local));
130:     // we have to proceed index by index because we want to zero each row after we compute the corresponding column
131:     for (PetscInt i = start; i < next; i++) {
132:       Mat row;
133:       Vec column, y;

135:       PetscCall(LMBasisGetVecRead(Y, i, &y));
136:       PetscCall(MatDenseGetColumnVec(dots->full, i % dots->m, &column));
137:       PetscCall(LMBasisGEMVH(X, oldest, next, 1.0, y, 0.0, column));
138:       PetscCall(MatDenseRestoreColumnVec(dots->full, i % dots->m, &column));
139:       PetscCall(LMBasisRestoreVecRead(Y, i, &y));

141:       // zero out the new row
142:       if (dots->m_local) {
143:         PetscCall(MatDenseGetSubMatrix(local, i % dots->m, (i % dots->m) + 1, PETSC_DECIDE, PETSC_DECIDE, &row));
144:         PetscCall(MatZeroEntries(row));
145:         PetscCall(MatDenseRestoreSubMatrix(local, &row));
146:       }
147:     }
148:   } break;
149:   case LMBLOCK_UPPER_TRIANGLE: {
150:     PetscInt mid       = next - (next % dots->m);
151:     PetscInt start_idx = start % dots->m;
152:     PetscInt next_idx  = ((next - 1) % dots->m) + 1;

154:     if (next_idx > start_idx) {
155:       PetscCall(LMBasisGEMMH(X, oldest, next, Y, start, next, 1.0, 0.0, dots->full));
156:     } else {
157:       PetscCall(LMBasisGEMMH(X, oldest, mid, Y, start, mid, 1.0, 0.0, dots->full));
158:       PetscCall(LMBasisGEMMH(X, oldest, next, Y, mid, next, 1.0, 0.0, dots->full));
159:     }
160:   } break;
161:   case LMBLOCK_FULL:
162:     PetscCall(LMBasisGEMMH(X, oldest, next, Y, start, next, 1.0, 0.0, dots->full));
163:     PetscCall(LMBasisGEMMH(X, start, next, Y, oldest, start, 1.0, 0.0, dots->full));
164:     break;
165:   default:
166:     PetscUnreachable();
167:   }
168:   dots->k = next;
169:   if (dots->debug) {
170:     const PetscScalar *values = NULL;
171:     PetscInt           lda;
172:     PetscInt           N;

174:     PetscCall(MatGetSize(X->vecs, &N, NULL));
175:     if (dots->block_type == LMBLOCK_DIAGONAL) {
176:       lda = 0;
177:       if (dots->update_diagonal_global) {
178:         PetscCall(VecGetArrayRead(dots->diagonal_global, &values));
179:       } else {
180:         PetscCall(VecGetArrayRead(dots->diagonal_dup, &values));
181:       }
182:     } else {
183:       PetscCall(MatDenseGetLDA(dots->full, &lda));
184:       PetscCall(MatDenseGetArrayRead(dots->full, &values));
185:     }
186:     for (PetscInt i = oldest; i < next; i++) {
187:       Vec       x_i_, x_i;
188:       PetscReal x_norm;
189:       PetscInt  j_start = oldest;
190:       PetscInt  j_end   = next;

192:       PetscCall(LMBasisGetVecRead(X, i, &x_i_));
193:       PetscCall(VecNorm(x_i_, NORM_1, &x_norm));
194:       PetscCall(VecDuplicate(x_i_, &x_i));
195:       PetscCall(VecCopy(x_i_, x_i));
196:       PetscCall(LMBasisRestoreVecRead(X, i, &x_i_));

198:       switch (dots->block_type) {
199:       case LMBLOCK_DIAGONAL:
200:         j_start = i;
201:         j_end   = i + 1;
202:         break;
203:       case LMBLOCK_UPPER_TRIANGLE:
204:         j_start = i;
205:         break;
206:       case LMBLOCK_STRICT_UPPER_TRIANGLE:
207:         j_start = i + 1;
208:         break;
209:       default:
210:         break;
211:       }
212:       for (PetscInt j = j_start; j < j_end; j++) {
213:         Vec         y_j;
214:         PetscScalar dot_true, dot = 0.0, diff;
215:         PetscReal   y_norm;

217:         PetscCall(LMBasisGetVecRead(Y, j, &y_j));
218:         PetscCall(VecDot(y_j, x_i, &dot_true));
219:         PetscCall(VecNorm(y_j, NORM_1, &y_norm));
220:         if (dots->m_local) dot = values[(j % dots->m) * lda + (i % dots->m)];
221:         PetscCallMPI(MPI_Bcast(&dot, 1, MPIU_SCALAR, 0, comm));
222:         diff = dot_true - dot;
223:         if (PetscDefined(USE_COMPLEX)) {
224:           PetscCheck(PetscAbsScalar(diff) <= PETSC_SMALL * N * x_norm * y_norm, comm, PETSC_ERR_PLIB, "LMProducts debug: dots[%" PetscInt_FMT ", %" PetscInt_FMT "] = %g + i*%g != VecDot() = %g + i*%g", i, j, (double)PetscRealPart(dot), (double)PetscImaginaryPart(dot), (double)PetscRealPart(dot_true), (double)PetscImaginaryPart(dot_true));
225:         } else {
226:           PetscCheck(PetscAbsScalar(diff) <= PETSC_SMALL * N * x_norm * y_norm, comm, PETSC_ERR_PLIB, "LMProducts debug: dots[%" PetscInt_FMT ", %" PetscInt_FMT "] = %g != VecDot() = %g", i, j, (double)PetscRealPart(dot), (double)PetscRealPart(dot_true));
227:         }
228:         PetscCall(LMBasisRestoreVecRead(Y, j, &y_j));
229:       }

231:       PetscCall(VecDestroy(&x_i));
232:     }

234:     if (dots->block_type == LMBLOCK_DIAGONAL) {
235:       if (dots->update_diagonal_global) {
236:         PetscCall(VecRestoreArrayRead(dots->diagonal_global, &values));
237:       } else {
238:         PetscCall(VecRestoreArrayRead(dots->diagonal_dup, &values));
239:       }
240:     } else {
241:       PetscCall(MatDenseRestoreArrayRead(dots->full, &values));
242:     }
243:   }
244:   PetscCall(PetscLogEventEnd(LMPROD_Update, NULL, NULL, NULL, NULL));
245:   PetscFunctionReturn(PETSC_SUCCESS);
246: }

248: // dots = X^H Y
249: PETSC_INTERN PetscErrorCode LMProductsUpdate(LMProducts dots, LMBasis X, LMBasis Y)
250: {
251:   PetscInt oldest, next;

253:   PetscFunctionBegin;
254:   PetscCall(LMBasisGetRange(X, &oldest, &next));
255:   PetscCall(LMProductsUpdate_Internal(dots, X, Y, oldest, next));
256:   PetscFunctionReturn(PETSC_SUCCESS);
257: }

259: PETSC_INTERN PetscErrorCode LMProductsCopy(LMProducts src, LMProducts dest)
260: {
261:   PetscFunctionBegin;
262:   PetscCheck(dest->m == src->m, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different size");
263:   PetscCheck(dest->m_local == src->m_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different size");
264:   PetscCheck(dest->block_type == src->block_type, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Cannot copy to LMProducts of different block type");
265:   dest->k       = src->k;
266:   dest->m_local = src->m_local;
267:   if (src->full) PetscCall(MatCopy(src->full, dest->full, DIFFERENT_NONZERO_PATTERN));
268:   if (src->diagonal_dup) PetscCall(VecCopy(src->diagonal_dup, dest->diagonal_dup));
269:   if (src->diagonal_global) PetscCall(VecCopy(src->diagonal_global, dest->diagonal_global));
270:   dest->update_diagonal_global = src->update_diagonal_global;
271:   dest->operator_id            = src->operator_id;
272:   dest->operator_state         = src->operator_state;
273:   PetscFunctionReturn(PETSC_SUCCESS);
274: }

276: PETSC_INTERN PetscErrorCode LMProductsScale(LMProducts dots, PetscScalar scale)
277: {
278:   PetscFunctionBegin;
279:   if (dots->full) PetscCall(MatScale(dots->full, scale));
280:   if (dots->diagonal_dup) PetscCall(VecScale(dots->diagonal_dup, scale));
281:   if (dots->diagonal_global) PetscCall(VecScale(dots->diagonal_global, scale));
282:   PetscFunctionReturn(PETSC_SUCCESS);
283: }

285: PETSC_INTERN PetscErrorCode LMProductsGetLocalMatrix(LMProducts dots, Mat *G_local, PetscInt *k, PetscBool *local_is_nonempty)
286: {
287:   PetscFunctionBegin;
288:   PetscCheck(dots->block_type != LMBLOCK_DIAGONAL, PETSC_COMM_SELF, PETSC_ERR_SUP, "Asking for full matrix of diagonal products");
289:   PetscCall(MatDenseGetLocalMatrix(dots->full, G_local));
290:   if (k) *k = dots->k;
291:   if (local_is_nonempty) *local_is_nonempty = (dots->m_local == dots->m) ? PETSC_TRUE : PETSC_FALSE;
292:   PetscFunctionReturn(PETSC_SUCCESS);
293: }

295: PETSC_INTERN PetscErrorCode LMProductsRestoreLocalMatrix(LMProducts dots, Mat *G_local, PetscInt *k)
296: {
297:   PetscFunctionBegin;
298:   if (G_local) *G_local = NULL;
299:   if (k) dots->k = *k;
300:   PetscFunctionReturn(PETSC_SUCCESS);
301: }

303: static PetscErrorCode LMProductsGetUpdatedDiagonal(LMProducts dots, Vec *diagonal)
304: {
305:   PetscFunctionBegin;
306:   if (!dots->update_diagonal_global) {
307:     PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
308:     if (dots->m_local) PetscCall(VecCopy(dots->diagonal_dup, dots->diagonal_local));
309:     PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
310:     dots->update_diagonal_global = PETSC_TRUE;
311:   }
312:   if (diagonal) *diagonal = dots->diagonal_global;
313:   PetscFunctionReturn(PETSC_SUCCESS);
314: }

316: PETSC_INTERN PetscErrorCode LMProductsGetLocalDiagonal(LMProducts dots, Vec *D_local)
317: {
318:   PetscFunctionBegin;
319:   PetscCall(LMProductsGetUpdatedDiagonal(dots, NULL));
320:   PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
321:   *D_local = dots->diagonal_local;
322:   PetscFunctionReturn(PETSC_SUCCESS);
323: }

325: PETSC_INTERN PetscErrorCode LMProductsRestoreLocalDiagonal(LMProducts dots, Vec *D_local)
326: {
327:   PetscFunctionBegin;
328:   PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
329:   *D_local = NULL;
330:   PetscFunctionReturn(PETSC_SUCCESS);
331: }

333: PETSC_INTERN PetscErrorCode LMProductsGetNextColumn(LMProducts dots, Vec *col)
334: {
335:   PetscFunctionBegin;
336:   PetscCheck(dots->block_type != LMBLOCK_DIAGONAL, PETSC_COMM_SELF, PETSC_ERR_SUP, "Asking for column of diagonal products");
337:   PetscCall(MatDenseGetColumnVecWrite(dots->full, dots->k % dots->m, col));
338:   PetscFunctionReturn(PETSC_SUCCESS);
339: }

341: PETSC_INTERN PetscErrorCode LMProductsRestoreNextColumn(LMProducts dots, Vec *col)
342: {
343:   PetscFunctionBegin;
344:   PetscCall(MatDenseRestoreColumnVecWrite(dots->full, dots->k % dots->m, col));
345:   dots->k++;
346:   PetscFunctionReturn(PETSC_SUCCESS);
347: }

349: // copy conj(triu(G)) into tril(G)
350: PETSC_INTERN PetscErrorCode LMProductsMakeHermitian(Mat local, PetscInt oldest, PetscInt next)
351: {
352:   PetscInt m;

354:   PetscFunctionBegin;
355:   PetscCall(MatGetLocalSize(local, &m, NULL));
356:   if (m) {
357:     // TODO: implement on device?
358:     PetscScalar *a;
359:     PetscInt     lda;

361:     PetscCall(MatDenseGetLDA(local, &lda));
362:     PetscCall(MatDenseGetArray(local, &a));
363:     for (PetscInt j_ = oldest; j_ < next; j_++) {
364:       PetscInt j = j_ % m;

366:       a[j + j * lda] = PetscRealPart(a[j + j * lda]);
367:       for (PetscInt i_ = j_ + 1; i_ < next; i_++) {
368:         PetscInt i = i_ % m;

370:         a[i + j * lda] = PetscConj(a[j + i * lda]);
371:       }
372:     }
373:   }
374:   PetscFunctionReturn(PETSC_SUCCESS);
375: }

377: PETSC_INTERN PetscErrorCode LMProductsSolve(LMProducts dots, PetscInt oldest, PetscInt next, Vec b, Vec x, PetscBool hermitian_transpose)
378: {
379:   PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
380:   PetscInt dots_next   = dots->k;
381:   Mat      local;
382:   Vec      diag = NULL;

384:   PetscFunctionBegin;
385:   PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
386:   if (oldest >= next) PetscFunctionReturn(PETSC_SUCCESS);
387:   PetscCall(PetscLogEventBegin(LMPROD_Solve, NULL, NULL, NULL, NULL));
388:   if (!dots->rhs_local) PetscCall(VecCreateLocalVector(b, &dots->rhs_local));
389:   if (!dots->lhs_local) PetscCall(VecDuplicate(dots->rhs_local, &dots->lhs_local));
390:   switch (dots->block_type) {
391:   case LMBLOCK_DIAGONAL:
392:     PetscCall(LMProductsGetUpdatedDiagonal(dots, &diag));
393:     PetscCall(VecDSVCyclic(hermitian_transpose, oldest, next, diag, b, x));
394:     break;
395:   case LMBLOCK_UPPER_TRIANGLE:
396:     PetscCall(MatSeqDenseTRSVCyclic(hermitian_transpose, oldest, next, dots->full, b, x));
397:     break;
398:   default: {
399:     PetscCall(MatDenseGetLocalMatrix(dots->full, &local));
400:     PetscCall(VecGetLocalVector(b, dots->rhs_local));
401:     PetscCall(VecGetLocalVector(x, dots->lhs_local));
402:     if (dots->m_local) {
403:       if (!hermitian_transpose) {
404:         PetscCall(MatSolve(local, dots->rhs_local, dots->lhs_local));
405:       } else {
406:         Vec rhs_conj = dots->rhs_local;

408:         if (PetscDefined(USE_COMPLEX)) {
409:           PetscCall(VecDuplicate(dots->rhs_local, &rhs_conj));
410:           PetscCall(VecCopy(dots->rhs_local, rhs_conj));
411:           PetscCall(VecConjugate(rhs_conj));
412:         }
413:         PetscCall(MatSolveTranspose(local, rhs_conj, dots->lhs_local));
414:         if (PetscDefined(USE_COMPLEX)) {
415:           PetscCall(VecConjugate(dots->lhs_local));
416:           PetscCall(VecDestroy(&rhs_conj));
417:         }
418:       }
419:     }
420:     if (x != b) PetscCall(VecRestoreLocalVector(x, dots->lhs_local));
421:     PetscCall(VecRestoreLocalVector(b, dots->rhs_local));
422:   } break;
423:   }
424:   PetscCall(PetscLogEventEnd(LMPROD_Solve, NULL, NULL, NULL, NULL));
425:   PetscFunctionReturn(PETSC_SUCCESS);
426: }

428: PETSC_INTERN PetscErrorCode LMProductsMult(LMProducts dots, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y, PetscBool hermitian_transpose)
429: {
430:   PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
431:   PetscInt dots_next   = dots->k;
432:   Vec      diag        = NULL;

434:   PetscFunctionBegin;
435:   PetscCall(PetscLogEventBegin(LMPROD_Mult, NULL, NULL, NULL, NULL));
436:   PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
437:   switch (dots->block_type) {
438:   case LMBLOCK_DIAGONAL: {
439:     PetscCall(LMProductsGetUpdatedDiagonal(dots, &diag));
440:     PetscCall(VecDMVCyclic(hermitian_transpose, oldest, next, alpha, diag, x, beta, y));
441:   } break;
442:   case LMBLOCK_STRICT_UPPER_TRIANGLE: // the lower triangle has been zeroed, MatMult() is safe
443:   case LMBLOCK_FULL:
444:     PetscCall(MatSeqDenseGEMVCyclic(hermitian_transpose, oldest, next, alpha, dots->full, x, beta, y));
445:     break;
446:   default:
447:     SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Not implemented");
448:   }
449:   PetscCall(PetscLogEventEnd(LMPROD_Mult, NULL, NULL, NULL, NULL));
450:   PetscFunctionReturn(PETSC_SUCCESS);
451: }

453: PETSC_INTERN PetscErrorCode LMProductsMultHermitian(LMProducts dots, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
454: {
455:   PetscInt dots_oldest = PetscMax(0, dots->k - dots->m);
456:   PetscInt dots_next   = dots->k;

458:   PetscFunctionBegin;
459:   PetscCheck(oldest >= dots_oldest && next <= dots_next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Invalid indices");
460:   if (dots->block_type == LMBLOCK_DIAGONAL) PetscCall(LMProductsMult(dots, oldest, next, alpha, x, beta, y, PETSC_FALSE));
461:   else {
462:     PetscCall(PetscLogEventBegin(LMPROD_Mult, NULL, NULL, NULL, NULL));
463:     PetscCall(MatSeqDenseHEMVCyclic(oldest, next, alpha, dots->full, x, beta, y));
464:     PetscCall(PetscLogEventEnd(LMPROD_Mult, NULL, NULL, NULL, NULL));
465:   }
466:   PetscFunctionReturn(PETSC_SUCCESS);
467: }

469: PETSC_INTERN PetscErrorCode LMProductsReset(LMProducts dots)
470: {
471:   PetscFunctionBegin;
472:   if (dots) {
473:     dots->k              = 0;
474:     dots->operator_id    = 0;
475:     dots->operator_state = 0;
476:     if (dots->full) {
477:       Mat full_local;

479:       PetscCall(MatDenseGetLocalMatrix(dots->full, &full_local));
480:       PetscCall(MatSetUnfactored(full_local));
481:       PetscCall(MatZeroEntries(full_local));
482:     }
483:     if (dots->diagonal_global) PetscCall(VecZeroEntries(dots->diagonal_dup));
484:     if (dots->diagonal_dup) PetscCall(VecZeroEntries(dots->diagonal_dup));
485:   }
486:   PetscFunctionReturn(PETSC_SUCCESS);
487: }

489: PETSC_INTERN PetscErrorCode LMProductsGetDiagonalValue(LMProducts dots, PetscInt i, PetscScalar *v)
490: {
491:   PetscFunctionBegin;
492:   PetscInt oldest = PetscMax(0, dots->k - dots->m);
493:   PetscInt next   = dots->k;
494:   PetscInt idx    = i % dots->m;
495:   PetscCheck(i >= oldest && i < next, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Inserting value %d out of range [%d, %d)", (int)i, (int)oldest, (int)next);
496:   PetscCall(VecGetValues(dots->diagonal_dup, 1, &idx, v));
497:   PetscFunctionReturn(PETSC_SUCCESS);
498: }

500: PETSC_INTERN PetscErrorCode LMProductsInsertNextDiagonalValue(LMProducts dots, PetscInt i, PetscScalar v)
501: {
502:   PetscInt idx = i % dots->m;

504:   PetscFunctionBegin;
505:   PetscCheck(i == dots->k, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is not the next index (%" PetscInt_FMT ")", i, dots->k);
506:   PetscCall(VecSetValue(dots->diagonal_dup, idx, v, INSERT_VALUES));
507:   if (dots->update_diagonal_global) {
508:     PetscScalar *array;
509:     PetscMemType memtype;

511:     PetscCall(VecGetArrayAndMemType(dots->diagonal_global, &array, &memtype));
512:     if (dots->m_local > 0) {
513:       if (PetscMemTypeHost(memtype)) {
514:         array[idx] = v;
515:         PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
516:       } else {
517:         PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
518:         PetscCall(VecGetLocalVector(dots->diagonal_global, dots->diagonal_local));
519:         if (dots->m_local) PetscCall(VecCopy(dots->diagonal_dup, dots->diagonal_local));
520:         PetscCall(VecRestoreLocalVector(dots->diagonal_global, dots->diagonal_local));
521:       }
522:     } else {
523:       PetscCall(VecRestoreArrayAndMemType(dots->diagonal_global, &array));
524:     }
525:   }
526:   dots->k++;
527:   PetscFunctionReturn(PETSC_SUCCESS);
528: }

530: PETSC_INTERN PetscErrorCode LMProductsOnesOnUnusedDiagonal(Mat A, PetscInt oldest, PetscInt next)
531: {
532:   PetscInt m;
533:   Mat      sub;

535:   PetscFunctionBegin;
536:   PetscCall(MatGetSize(A, &m, NULL));
537:   // we could handle the general case but this is the only case used by MatLMVM
538:   PetscCheck((next < m && oldest == 0) || next - oldest == m, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "General case not implemented");
539:   if (next - oldest == m) PetscFunctionReturn(PETSC_SUCCESS); // nothing to do if all entries are used
540:   PetscCall(MatDenseGetSubMatrix(A, next, m, next, m, &sub));
541:   PetscCall(MatShift(sub, 1.0));
542:   PetscCall(MatDenseRestoreSubMatrix(A, &sub));
543:   PetscFunctionReturn(PETSC_SUCCESS);
544: }