Actual source code: cd_utils.c
1: #include <../src/ksp/ksp/utils/lmvm/dense/denseqn.h>
2: #include <petscblaslapack.h>
3: #include <petscmat.h>
4: #include <petscsys.h>
5: #include <petscsystypes.h>
6: #include <petscis.h>
7: #include <petscoptions.h>
8: #include <petscdevice.h>
9: #include <petsc/private/deviceimpl.h>
/*
  String table for MatLMVMDenseType: the two strategy names, followed by the
  enum type name, the enum constant prefix, and a terminating NULL.
*/
const char *const MatLMVMDenseTypes[] = {
  "reorder",
  "inplace",
  "MatLMVMDenseType",
  "MAT_LMVM_DENSE_",
  NULL,
};
13: PETSC_INTERN PetscErrorCode VecCyclicShift(Mat B, Vec X, PetscInt d, Vec cyclic_work_vec)
14: {
15: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
16: PetscInt m = lmvm->m;
17: PetscInt n;
18: const PetscScalar *src;
19: PetscScalar *dest;
20: PetscMemType src_memtype;
21: PetscMemType dest_memtype;
23: PetscFunctionBegin;
24: PetscCall(VecGetLocalSize(X, &n));
25: if (!cyclic_work_vec) PetscCall(VecDuplicate(X, &cyclic_work_vec));
26: PetscCall(VecCopy(X, cyclic_work_vec));
27: PetscCall(VecGetArrayReadAndMemType(cyclic_work_vec, &src, &src_memtype));
28: PetscCall(VecGetArrayWriteAndMemType(X, &dest, &dest_memtype));
29: if (n == 0) { /* no work on this process */
30: PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
31: PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
32: PetscFunctionReturn(PETSC_SUCCESS);
33: }
34: PetscAssert(src_memtype == dest_memtype, PETSC_COMM_SELF, PETSC_ERR_PLIB, "memtype of duplicate does not match");
35: if (PetscMemTypeHost(src_memtype)) {
36: PetscCall(PetscArraycpy(dest, &src[d], m - d));
37: PetscCall(PetscArraycpy(&dest[m - d], src, d));
38: } else {
39: PetscDeviceContext dctx;
41: PetscCall(PetscDeviceContextGetCurrentContext(&dctx));
42: PetscCall(PetscDeviceRegisterMemory(dest, dest_memtype, m * sizeof(*dest)));
43: PetscCall(PetscDeviceRegisterMemory(src, src_memtype, m * sizeof(*src)));
44: PetscCall(PetscDeviceArrayCopy(dctx, dest, &src[d], m - d));
45: PetscCall(PetscDeviceArrayCopy(dctx, &dest[m - d], src, d));
46: }
47: PetscCall(VecRestoreArrayWriteAndMemType(X, &dest));
48: PetscCall(VecRestoreArrayReadAndMemType(cyclic_work_vec, &src));
49: PetscFunctionReturn(PETSC_SUCCESS);
50: }
52: static inline PetscInt recycle_index(PetscInt m, PetscInt idx)
53: {
54: return idx % m;
55: }
57: static inline PetscInt oldest_update(PetscInt m, PetscInt idx)
58: {
59: return PetscMax(0, idx - m);
60: }
62: PETSC_INTERN PetscErrorCode VecRecycleOrderToHistoryOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
63: {
64: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
65: PetscInt m = lmvm->m;
66: PetscInt oldest_index;
68: PetscFunctionBegin;
69: oldest_index = recycle_index(m, oldest_update(m, num_updates));
70: if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in history order */
71: PetscCall(VecCyclicShift(B, X, oldest_index, cyclic_work_vec));
72: PetscFunctionReturn(PETSC_SUCCESS);
73: }
75: PETSC_INTERN PetscErrorCode VecHistoryOrderToRecycleOrder(Mat B, Vec X, PetscInt num_updates, Vec cyclic_work_vec)
76: {
77: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
78: PetscInt m = lmvm->m;
79: PetscInt oldest_index;
81: PetscFunctionBegin;
82: oldest_index = recycle_index(m, oldest_update(m, num_updates));
83: if (oldest_index == 0) PetscFunctionReturn(PETSC_SUCCESS); /* vector is already in recycle order */
84: PetscCall(VecCyclicShift(B, X, m - oldest_index, cyclic_work_vec));
85: PetscFunctionReturn(PETSC_SUCCESS);
86: }
88: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace_Internal(MatLMVMDenseType type, PetscMemType memtype, PetscBool hermitian_transpose, PetscInt m, PetscInt oldest, PetscInt next, const PetscScalar A[], PetscInt lda, PetscScalar x[], PetscInt stride)
89: {
90: PetscInt oldest_index = oldest % m;
91: PetscInt next_index = (next - 1) % m + 1;
92: PetscInt N = next - oldest;
94: PetscFunctionBegin;
95: /* if oldest_index == 0, the two strategies are equivalent, redirect to the simpler one */
96: if (oldest_index == 0) type = MAT_LMVM_DENSE_REORDER;
97: switch (type) {
98: case MAT_LMVM_DENSE_REORDER:
99: if (PetscMemTypeHost(memtype)) {
100: PetscBLASInt n, lda_blas, one = 1;
101: PetscCall(PetscBLASIntCast(N, &n));
102: PetscCall(PetscBLASIntCast(lda, &lda_blas));
103: PetscCallBLAS("BLAStrsv", BLAStrsv_("U", hermitian_transpose ? "C" : "N", "NotUnitTriangular", &n, A, &lda_blas, x, &one));
104: PetscCall(PetscLogFlops(1.0 * n * n));
105: #if defined(PETSC_HAVE_CUPM)
106: } else if (PetscMemTypeDevice(memtype)) {
107: PetscCall(MatUpperTriangularSolveInPlace_CUPM(hermitian_transpose, N, A, lda, x, 1));
108: #endif
109: } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
110: break;
111: case MAT_LMVM_DENSE_INPLACE:
112: if (PetscMemTypeHost(memtype)) {
113: PetscBLASInt n_old, n_new, lda_blas, one = 1;
114: PetscScalar minus_one = -1.0;
115: PetscScalar sone = 1.0;
116: PetscCall(PetscBLASIntCast(m - oldest_index, &n_old));
117: PetscCall(PetscBLASIntCast(next_index, &n_new));
118: PetscCall(PetscBLASIntCast(lda, &lda_blas));
119: if (!hermitian_transpose) {
120: if (n_new > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
121: if (n_new > 0 && n_old > 0) PetscCallBLAS("BLASgemv", BLASgemv_("N", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, x, &one, &sone, &x[oldest_index], &one));
122: if (n_old > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "N", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
123: } else {
124: if (n_old > 0) {
125: PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_old, &A[oldest_index * (lda + 1)], &lda_blas, &x[oldest_index], &one));
126: if (n_new > 0 && n_old > 0) PetscCallBLAS("BLASgemv", BLASgemv_("C", &n_old, &n_new, &minus_one, &A[oldest_index], &lda_blas, &x[oldest_index], &one, &sone, x, &one));
127: }
128: if (n_new > 0) PetscCallBLAS("BLAStrsv", BLAStrsv_("U", "C", "NotUnitTriangular", &n_new, A, &lda_blas, x, &one));
129: }
130: PetscCall(PetscLogFlops(1.0 * N * N));
131: #if defined(PETSC_HAVE_CUPM)
132: } else if (PetscMemTypeDevice(memtype)) {
133: PetscCall(MatUpperTriangularSolveInPlaceCyclic_CUPM(hermitian_transpose, m, oldest, next, A, lda, x, stride));
134: #endif
135: } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Unsupported memtype");
136: break;
137: default:
138: PetscUnreachable();
139: }
140: PetscFunctionReturn(PETSC_SUCCESS);
141: }
143: PETSC_INTERN PetscErrorCode MatUpperTriangularSolveInPlace(Mat B, Mat Amat, Vec X, PetscBool hermitian_transpose, PetscInt num_updates, MatLMVMDenseType strategy)
144: {
145: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
146: PetscInt m = lmvm->m;
147: PetscInt h, local_n;
148: PetscInt lda;
149: PetscScalar *x;
150: PetscMemType memtype_r, memtype_x;
151: const PetscScalar *A;
153: PetscFunctionBegin;
154: h = num_updates - oldest_update(m, num_updates);
155: if (!h) PetscFunctionReturn(PETSC_SUCCESS);
156: PetscCall(VecGetLocalSize(X, &local_n));
157: PetscCall(VecGetArrayAndMemType(X, &x, &memtype_x));
158: PetscCall(MatDenseGetArrayReadAndMemType(Amat, &A, &memtype_r));
159: if (!local_n) {
160: PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
161: PetscCall(VecRestoreArrayAndMemType(X, &x));
162: PetscFunctionReturn(PETSC_SUCCESS);
163: }
164: PetscAssert(memtype_x == memtype_r, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Incompatible device pointers");
165: PetscCall(MatDenseGetLDA(Amat, &lda));
166: PetscCall(MatUpperTriangularSolveInPlace_Internal(strategy, memtype_x, hermitian_transpose, m, oldest_update(m, num_updates), num_updates, A, lda, x, 1));
167: PetscCall(VecRestoreArrayWriteAndMemType(X, &x));
168: PetscCall(MatDenseRestoreArrayReadAndMemType(Amat, &A));
169: PetscFunctionReturn(PETSC_SUCCESS);
170: }
172: /* Shifts R[end-m_keep:end,end-m_keep:end] to R[0:m_keep, 0:m_keep] */
174: PETSC_INTERN PetscErrorCode MatMove_LR3(Mat B, Mat R, PetscInt m_keep)
175: {
176: Mat_LMVM *lmvm = (Mat_LMVM *)B->data;
177: Mat_DQN *lqn = (Mat_DQN *)lmvm->ctx;
178: PetscInt M;
179: Mat mat_local, local_sub, local_temp, temp_sub;
181: PetscFunctionBegin;
182: if (!lqn->temp_mat) PetscCall(MatDuplicate(R, MAT_SHARE_NONZERO_PATTERN, &lqn->temp_mat));
183: PetscCall(MatGetLocalSize(R, &M, NULL));
184: if (M == 0) PetscFunctionReturn(PETSC_SUCCESS);
186: PetscCall(MatDenseGetLocalMatrix(R, &mat_local));
187: PetscCall(MatDenseGetLocalMatrix(lqn->temp_mat, &local_temp));
188: PetscCall(MatDenseGetSubMatrix(mat_local, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &local_sub));
189: PetscCall(MatDenseGetSubMatrix(local_temp, lmvm->m - m_keep, lmvm->m, lmvm->m - m_keep, lmvm->m, &temp_sub));
190: PetscCall(MatCopy(local_sub, temp_sub, SAME_NONZERO_PATTERN));
191: PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
192: PetscCall(MatDenseGetSubMatrix(mat_local, 0, m_keep, 0, m_keep, &local_sub));
193: PetscCall(MatCopy(temp_sub, local_sub, SAME_NONZERO_PATTERN));
194: PetscCall(MatDenseRestoreSubMatrix(mat_local, &local_sub));
195: PetscCall(MatDenseRestoreSubMatrix(local_temp, &temp_sub));
196: PetscFunctionReturn(PETSC_SUCCESS);
197: }