Actual source code: cd_utils.c

  1: #include <../src/ksp/ksp/utils/lmvm/dense/denseqn.h>
  2: #include <petscblaslapack.h>
  3: #include <petscmat.h>
  4: #include <petscsys.h>
  5: #include <petscsystypes.h>
  6: #include <petscis.h>
  7: #include <petscoptions.h>
  8: #include <petscdevice.h>
  9: #include <petsc/private/deviceimpl.h>

 11: const char *const MatLMVMDenseTypes[] = {"reorder", "inplace", "MatLMVMDenseType", "MAT_LMVM_DENSE_", NULL};

 13: PETSC_INTERN PetscErrorCode VecCyclicShift(Mat B, Vec X, PetscInt d, Vec cyclic_work_vec)
 14: {
 15:   Mat_LMVM          *lmvm = (Mat_LMVM *)B->data;
 16:   PetscInt           m    = lmvm->m;
 17:   PetscInt           n;
 18:   const PetscScalar *src;
 19:   PetscScalar       *dest;
 20:   PetscMemType       src_memtype;
 21:   PetscMemType       dest_memtype;

 23:   PetscFunctionBegin;
 24:   PetscCall(VecGetLocalSize(X, &n));
 25:   if (!cyclic_work_vec) PetscCall(VecDuplicate(X, &cyclic_work_vec));
 26:   PetscCall(VecCopy(X, cyclic_work_vec));
 27:   PetscCall(VecGetArrayReadAndMemType(cyclic_work_vec, &src, &src_memtype));
 28:   PetscCall(VecGetArrayWriteAndMemType(X, &dest, &dest_memtype));
 29:   if (n == 0) { /* no work on this process */
 30:     PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
 31:     PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
 32:     PetscFunctionReturn(PETSC_SUCCESS);
 33:   }
 34:   PetscAssert(src_memtype == dest_memtype, PETSC_COMM_SELF, PETSC_ERR_PLIB, "memtype of duplicate does not match");
 35:   if (PetscMemTypeHost(src_memtype)) {
 36:     PetscCall(PetscArraycpy(dest, &src[d], m - d));
 37:     PetscCall(PetscArraycpy(&dest[m - d], src, d));
 38:   } else {
 39:     PetscDeviceContext dctx;

 41:     PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
 42:     PetscCall(PetscDeviceRegisterMemory(dest, dest_memtype, m * sizeof(*dest)));
 43:     PetscCall(PetscDeviceRegisterMemory(src, src_memtype, m * sizeof(*src)));
 44:     PetscCall(PetscDeviceArrayCopy(dctx, dest, &src[d], m - d));
 45:     PetscCall(PetscDeviceArrayCopy(dctx, &dest[m - d], src, d));
 46:   }
 47:   PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
 48:   PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
 49:   PetscFunctionReturn(PETSC_SUCCESS);
 50: }

 52: static inline PetscInt recycle_index(PetscInt m, PetscInt idx)
 53: {
 54:   return idx % m;
 55: }

 57: static inline PetscInt oldest_update(PetscInt m, PetscInt idx)
 58: {
 59:   return PetscMax(0, idx - m);
 60: }

 62: PETSC_INTERN PetscErrorCode VecRecycleOrderToHistoryOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
 63: {
 64:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
 65:   PetscInt  m    = lmvm->m;
 66:   PetscInt  oldest_index;

 68:   PetscFunctionBegin;
 69:   oldest_index = recycle_index(m, oldest_update(m, num_updates));
 70:   if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in history order */
 71:   PetscCall(VecCyclicShift(B, X, oldest_index, cyclic_work_vec));
 72:   PetscFunctionReturn(PETSC_SUCCESS);
 73: }

 75: PETSC_INTERN PetscErrorCode VecHistoryOrderToRecycleOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
 76: {
 77:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
 78:   PetscInt  m    = lmvm->m;
 79:   PetscInt  oldest_index;

 81:   PetscFunctionBegin;
 82:   oldest_index = recycle_index(m, oldest_update(m, num_updates));
 83:   if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in recycle order */
 84:   PetscCall(VecCyclicShift(B, X, m - oldest_index, cyclic_work_vec));
 85:   PetscFunctionReturn(PETSC_SUCCESS);
 86: }

 88: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_Internal(MatLMVMDenseType type, PetscMemType memtype, PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
 89: {
 90:   PetscInt oldest_index = oldest % m;
 91:   PetscInt next_index   = (next - 1) % m + 1;
 92:   PetscInt N            = next - oldest;

 94:   PetscFunctionBegin;
 95:   /* if oldest_index == 0, the two strategies are equivalent, redirect to the simpler one */
 96:   if (oldest_index == 0) type = MAT_LMVM_DENSE_REORDER;
 97:   switch (type) {
 98:   case MAT_LMVM_DENSE_REORDER:
 99:     if (PetscMemTypeHost(memtype)) {
100:       PetscBLASInt n, lda_blas, one = 1;
101:       PetscCall(PetscBLASIntCast(N, &n));
102:       PetscCall(PetscBLASIntCast(lda, &lda_blas));
103:       PetscCallBLAS("BLAStrsv", BLAStrsv_("U", hermitian_transpose ? "C" : "N", "NotUnitTriangular", &n, A, &lda_blas, x, &one));
104:       PetscCall(PetscLogFlops(1.0 * n * n));
105: #if defined(PETSC_HAVE_CUPM)
106:     } else if (PetscMemTypeDevice(memtype)) {
107:       PetscCall(MatUpperTriangularSolveInPlace_CUPM(hermitian_transpose, N, A, lda, x, 1));
108: #endif
109:     } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
110:     break;
111:   case MAT_LMVM_DENSE_INPLACE:
112:     if (PetscMemTypeHost(memtype)) {
113:       PetscBLASInt n_old, n_new, lda_blas, one = 1;
114:       PetscScalar  minus_one = -1.0;
115:       PetscScalar  sone      = 1.0;
116:       PetscCall(PetscBLASIntCast(m - oldest_index, &n_old));
117:       PetscCall(PetscBLASIntCast(next_index, &n_new));
118:       PetscCall(PetscBLASIntCast(lda, &lda_blas));
119:       if (!hermitian_transpose) {
120:         if (n_new > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
121:         if (n_new > 0 && n_old > 0) PetscCallBLAS("BLASgemv", BLASgemv_("N", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, x, &one, &sone, &x[oldest_index], &one));
122:         if (n_old > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
123:       } else {
124:         if (n_old > 0) {
125:           PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
126:           if (n_new > 0 && n_old > 0) PetscCallBLAS("BLASgemv", BLASgemv_("C", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, &x[oldest_index], &one, &sone, x, &one));
127:         }
128:         if (n_new > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
129:       }
130:       PetscCall(PetscLogFlops(1.0 * N * N));
131: #if defined(PETSC_HAVE_CUPM)
132:     } else if (PetscMemTypeDevice(memtype)) {
133:       PetscCall(MatUpperTriangularSolveInPlaceCyclic_CUPM(hermitian_transpose, m, oldest, next, A, lda, x, stride));
134: #endif
135:     } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
136:     break;
137:   default:
138:     PetscUnreachable();
139:   }
140:   PetscFunctionReturn(PETSC_SUCCESS);
141: }

143: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace(Mat B, Mat Amat, Vec X, PetscBool hermitian_transpose, PetscInt num_updates, MatLMVMDenseType strategy)
144: {
145:   Mat_LMVM          *lmvm = (Mat_LMVM *)B->data;
146:   PetscInt           m    = lmvm->m;
147:   PetscInt           h, local_n;
148:   PetscInt           lda;
149:   PetscScalar       *x;
150:   PetscMemType       memtype_r, memtype_x;
151:   const PetscScalar *A;

153:   PetscFunctionBegin;
154:   h = num_updates - oldest_update(m, num_updates);
155:   if (!h) PetscFunctionReturn(PETSC_SUCCESS);
156:   PetscCall(VecGetLocalSize(X, &local_n));
157:   PetscCall(VecGetArrayAndMemType(X, &x, &memtype_x));
158:   PetscCall(MatDenseGetArrayReadAndMemType(Amat, &A, &memtype_r));
159:   if (!local_n) {
160:     PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
161:     PetscCall(VecRestoreArrayAndMemType(X, &x));
162:     PetscFunctionReturn(PETSC_SUCCESS);
163:   }
164:   PetscAssert(memtype_x == memtype_r, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Incompatible device pointers");
165:   PetscCall(MatDenseGetLDA(Amat, &lda));
166:   PetscCall(MatUpperTriangularSolveInPlace_Internal(strategy, memtype_x, hermitian_transpose, m, oldest_update(m, num_updates), num_updates, A, lda, x, 1));
167:   PetscCall(VecRestoreArrayWriteAndMemType(X, &x));
168:   PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
169:   PetscFunctionReturn(PETSC_SUCCESS);
170: }

172: /* Shifts R[end-m_keep:end,end-m_keep:end] to R[0:m_keep, 0:m_keep] */

174: PETSC_INTERN PetscErrorCode MatMove_LR3(Mat B, Mat R, PetscInt m_keep)
175: {
176:   Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
177:   Mat_DQN  *lqn  = (Mat_DQN *)lmvm->ctx;
178:   PetscInt  M;
179:   Mat       mat_local, local_sub, local_temp, temp_sub;

181:   PetscFunctionBegin;
182:   if (!lqn->temp_mat) PetscCall(MatDuplicate(R, MAT_SHARE_NONZERO_PATTERN, &lqn->temp_mat));
183:   PetscCall(MatGetLocalSize(R, &M, NULL));
184:   if (M == 0) PetscFunctionReturn(PETSC_SUCCESS);

186:   PetscCall(MatDenseGetLocalMatrix(R, &mat_local));
187:   PetscCall(MatDenseGetLocalMatrix(lqn->temp_mat, &local_temp));
188:   PetscCall(MatDenseGetSubMatrix(mat_local, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &local_sub));
189:   PetscCall(MatDenseGetSubMatrix(local_temp, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &temp_sub));
190:   PetscCall(MatCopy(local_sub, temp_sub, SAME_NONZERO_PATTERN));
191:   PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
192:   PetscCall(MatDenseGetSubMatrix(mat_local, 0, m_keep, 0, m_keep, &local_sub));
193:   PetscCall(MatCopy(temp_sub, local_sub, SAME_NONZERO_PATTERN));
194:   PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
195:   PetscCall(MatDenseRestoreSubMatrix(local_temp, &temp_sub));
196:   PetscFunctionReturn(PETSC_SUCCESS);
197: }