Actual source code: letkf_local_analysis.kokkos.cxx
#include "../src/ml/da/impls/ensemble/letkf/letkf.h"
#include <Kokkos_Core.hpp>
#include <KokkosBlas.hpp>

#if defined(KOKKOS_ENABLE_CUDA)
#include <cusolverDn.h>
#include <cuda_runtime.h>
#include <petscdevice_cuda.h>
#elif defined(KOKKOS_ENABLE_HIP)
#include <rocsolver/rocsolver.h>
#include <hip/hip_runtime.h>
#include <petscdevice_hip.h>
#elif defined(KOKKOS_ENABLE_SYCL)
#include <oneapi/mkl.hpp>
#include <sycl/sycl.hpp>
#endif
/* ========================================================================== */
/*                  Batched Eigendecomposition for LETKF                      */
/* ========================================================================== */

/* Structure to hold reusable workspace for eigensolvers */
struct EigenWorkspace {
  /* Tracking for reuse */
  PetscInt max_chunk_size;
  PetscInt m;
  PetscInt n_obs_vertex;

  /* Persistent Kokkos Views */
  using exec_space = Kokkos::DefaultExecutionSpace;
  using view_3d    = Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, exec_space>;
  using view_2d    = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>;

  view_3d Z_batch;
  view_3d S_batch;
  view_3d T_batch;
  view_3d V_batch;
  view_2d Lambda_batch;
  view_3d T_sqrt_batch;
  view_2d w_batch;
  view_2d delta_batch;
  view_2d y_batch;
  view_2d y_mean_batch;
  view_2d r_inv_sqrt_batch;
  view_2d temp1_batch;
  view_2d temp2_batch;
  view_2d inv_sqrt_lambda_batch;

  /* Host workspace */
  PetscScalar *all_v;
  PetscReal   *all_lambda;
  PetscScalar *all_work;
#if defined(PETSC_USE_COMPLEX)
  PetscReal *all_rwork;
#endif
  PetscBLASInt lwork;
  PetscBLASInt n_blas;

  /* Device workspace (identical layout on all backends; syevj_params is CUDA-only) */
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
  syevjInfo_t syevj_params;
#endif
  PetscScalar *d_work;
  int         *d_info;
  PetscScalar *d_A_contig;
  PetscScalar *d_W_contig;
  int          lwork_device;
#endif

  EigenWorkspace() : max_chunk_size(0), m(0), n_obs_vertex(0), all_v(nullptr), all_lambda(nullptr), all_work(nullptr)
  {
#if defined(PETSC_USE_COMPLEX)
    all_rwork = nullptr;
#endif
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
    d_work     = nullptr;
    d_info     = nullptr;
    d_A_contig = nullptr;
    d_W_contig = nullptr;
#if defined(KOKKOS_ENABLE_CUDA)
    syevj_params = nullptr;
#endif
#endif
  }
};
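
/*
  Lifecycle sketch (matches how PetscDALETKFLocalAnalysis_GPU below uses this workspace;
  everything other than the EigenWorkspace members is illustrative only):

    EigenWorkspace *work = static_cast<EigenWorkspace *>(impl->eigen_work);
    if (!work) { work = new EigenWorkspace(); impl->eigen_work = work; }
    if (work->max_chunk_size < chunk_size || work->m != m || work->n_obs_vertex != p) {
      // free the old device/host buffers, then reallocate Views and solver workspace
    }
    // ... reuse work->T_batch, work->Lambda_batch, ... across all chunks ...

  The Views are released by Kokkos when the struct is deleted in
  PetscDALETKFDestroyLocalization_Kokkos; the raw device buffers are freed explicitly there.
*/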

/*
  BatchedEigenSolve_Host - Compute eigendecomposition for a batch of symmetric matrices (CPU version)

  Input Parameters:
+ T_batch - batch of symmetric matrices (n_batch x n_size x n_size)
. n_batch - number of matrices in the batch
. n_size  - size of each matrix (n_size x n_size)
- work    - reusable workspace structure

  Output Parameters:
+ Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
- V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

  Notes:
  Calls LAPACK's syev routine once per matrix, with the batch processed in parallel
  over the host execution space.
*/
#if !defined(KOKKOS_ENABLE_CUDA) && !defined(KOKKOS_ENABLE_HIP) && !defined(KOKKOS_ENABLE_SYCL)
#include <petscblaslapack.h>
static PetscErrorCode BatchedEigenSolve_Host(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, EigenWorkspace *work)
{
  PetscFunctionBegin;
  /* Create host mirrors and copy data in one operation. */
  /* This is required for HIP+complex, where create_mirror_view + deep_copy fails. */
  auto T_host      = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), T_batch);
  auto Lambda_host = Kokkos::create_mirror_view(Kokkos::HostSpace(), Lambda_batch);
  auto V_host      = Kokkos::create_mirror_view(Kokkos::HostSpace(), V_batch);

  /* Use pre-allocated workspace */
  PetscScalar *all_v      = work->all_v;
  PetscReal   *all_lambda = work->all_lambda;
  PetscScalar *all_work   = work->all_work;
  PetscBLASInt lwork      = work->lwork;
  PetscBLASInt n_blas     = work->n_blas;
#if defined(PETSC_USE_COMPLEX)
  PetscReal *all_rwork = work->all_rwork;
#endif

  /* Process each matrix in parallel on the host using LAPACK */
  Kokkos::parallel_for(
    "BatchedEigenSolve_Host", Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      PetscBLASInt n   = n_blas;
      PetscBLASInt lda = n;
      PetscBLASInt info;
      PetscBLASInt lw = lwork;

      /* Pointers for this matrix */
      PetscScalar *v_ptr      = all_v + i * n_size * n_size;
      PetscReal   *lambda_ptr = all_lambda + i * n_size;
      PetscScalar *work_ptr   = all_work + i * lwork;
#if defined(PETSC_USE_COMPLEX)
      PetscReal *rwork_ptr = all_rwork + i * (3 * n_size - 2);
#endif

      /* Copy T_host(i, :, :) to v_ptr (column-major) */
      for (PetscInt j = 0; j < n_size; j++) {
        for (PetscInt k = 0; k < n_size; k++) v_ptr[k + j * n_size] = T_host(i, k, j);
      }

      /* Compute eigendecomposition: T = V * Lambda * V^T */
#if defined(PETSC_USE_COMPLEX)
      LAPACKsyev_("V", "U", &n, v_ptr, &lda, lambda_ptr, work_ptr, &lw, rwork_ptr, &info);
#else
      LAPACKsyev_("V", "U", &n, v_ptr, &lda, lambda_ptr, work_ptr, &lw, &info);
#endif

      if (info != 0) {
        /* A lambda cannot return a PETSc error code, so we abort here; production code
           should instead accumulate failures via a parallel reduction and report them. */
        Kokkos::abort("LAPACK eigendecomposition failed in parallel region");
      }

      /* Copy results back to host views */
      for (PetscInt j = 0; j < n_size; j++) {
        Lambda_host(i, j) = (PetscScalar)lambda_ptr[j];
        for (PetscInt k = 0; k < n_size; k++) V_host(i, k, j) = v_ptr[k + j * n_size];
      }
    });

  /* Copy results back to device */
  Kokkos::deep_copy(Lambda_batch, Lambda_host);
  Kokkos::deep_copy(V_batch, V_host);
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif

/*
  BatchedEigenSolve_Device - Compute eigendecomposition for a batch of symmetric matrices (device version)

  Input Parameters:
+ T_batch - batch of symmetric matrices (n_batch x n_size x n_size)
. n_batch - number of matrices in the batch
. n_size  - size of each matrix (n_size x n_size)
. device_handle - device-specific solver handle (cusolverDnHandle_t, rocblas_handle, or sycl::queue*)
- work    - reusable workspace structure

  Output Parameters:
+ Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
- V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

  Notes:
  Uses vendor-specific symmetric eigensolvers:
  - CUDA: cuSOLVER's natively batched syevjBatched
  - HIP:  rocSOLVER's syevd, called once per matrix
  - SYCL: oneMKL's syevd, called once per matrix
*/
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, cusolverDnHandle_t cusolverH, EigenWorkspace *work)
{
  cusolverStatus_t cusolver_status;

  PetscFunctionBegin;
  /* Use pre-allocated workspace */
  syevjInfo_t  syevj_params = work->syevj_params;
  PetscScalar *d_work       = work->d_work;
  int         *d_info       = work->d_info;
  PetscScalar *d_A_contig   = work->d_A_contig;
  PetscScalar *d_W_contig   = work->d_W_contig;
  int          lwork        = work->lwork_device;

  /* Copy T_batch to a contiguous column-major layout for cuSOLVER */
  Kokkos::parallel_for(
    "ReorganizeForCuSOLVER", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
      }
    });
  Kokkos::fence();

  /* Solve batched eigendecomposition */
#if defined(PETSC_USE_REAL_SINGLE)
  cusolver_status = cusolverDnSsyevjBatched(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, n_size, d_A_contig, n_size, d_W_contig, d_work, lwork, d_info, syevj_params, n_batch);
#else
  cusolver_status = cusolverDnDsyevjBatched(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, n_size, d_A_contig, n_size, d_W_contig, d_work, lwork, d_info, syevj_params, n_batch);
#endif
  PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDn*syevjBatched failed");

  /* Check per-matrix convergence info */
  int *h_info;
  PetscCall(PetscMalloc1(n_batch, &h_info));
  PetscCallCUDA(cudaMemcpy(h_info, d_info, sizeof(int) * n_batch, cudaMemcpyDeviceToHost));
  for (PetscInt i = 0; i < n_batch; i++) PetscCheck(h_info[i] == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "cuSOLVER eigendecomposition failed for matrix %" PetscInt_FMT ": info=%d", i, h_info[i]);
  PetscCall(PetscFree(h_info));

  /* Copy results back from contiguous layout to V_batch */
  Kokkos::parallel_for(
    "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
        Lambda_batch(i, j) = d_W_contig[i * n_size + j]; /* CUDA 12.6 nvcc hangs if this line is placed before the V_batch loop */
      }
    });
  Kokkos::fence();
  PetscFunctionReturn(PETSC_SUCCESS);
}
#elif defined(KOKKOS_ENABLE_HIP)
static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, rocblas_handle rocblasH, EigenWorkspace *work)
{
  PetscFunctionBegin;
  /* Use pre-allocated workspace */
  PetscScalar *d_work = work->d_work;
  (void)d_work; /* silence unused-variable warning in complex builds, which error out below */
  int         *d_info     = work->d_info;
  PetscScalar *d_A_contig = work->d_A_contig;
  PetscScalar *d_W_contig = work->d_W_contig;

  /* Copy T_batch to a contiguous column-major layout for rocSOLVER */
  Kokkos::parallel_for(
    "ReorganizeForRocSOLVER", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
      }
    });
  Kokkos::fence();

  /* rocSOLVER has no batched symmetric eigensolver suitable here, so we call the
     divide-and-conquer rocsolver_?syevd once per matrix in the batch */
#if defined(PETSC_USE_COMPLEX)
  SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Complex numbers not supported on HIP backend for LETKF");
#else
  for (int i = 0; i < n_batch; i++) {
    PetscScalar   *A_ptr    = d_A_contig + i * n_size * n_size;
    PetscScalar   *W_ptr    = d_W_contig + i * n_size;
    int           *info_ptr = d_info + i;
    rocblas_status hip_status;

#if defined(PETSC_USE_REAL_SINGLE)
    hip_status = rocsolver_ssyevd(rocblasH, rocblas_evect_original, rocblas_fill_upper, n_size, A_ptr, n_size, W_ptr, d_work, info_ptr);
#else
    hip_status = rocsolver_dsyevd(rocblasH, rocblas_evect_original, rocblas_fill_upper, n_size, A_ptr, n_size, W_ptr, d_work, info_ptr);
#endif
    PetscCheck(hip_status == rocblas_status_success, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocsolver_*syevd failed for batch %d", i);
  }
#endif

  /* Check per-matrix convergence info */
  int *h_info;
  PetscCall(PetscMalloc1(n_batch, &h_info));
  PetscCallHIP(hipMemcpy(h_info, d_info, sizeof(int) * n_batch, hipMemcpyDeviceToHost));
  for (PetscInt i = 0; i < n_batch; i++) PetscCheck(h_info[i] == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocSOLVER eigendecomposition failed for matrix %" PetscInt_FMT ": info=%d", i, h_info[i]);
  PetscCall(PetscFree(h_info));

  /* Copy results back from contiguous layout to V_batch */
  Kokkos::parallel_for(
    "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
        Lambda_batch(i, j) = d_W_contig[i * n_size + j];
      }
    });
  Kokkos::fence();
  PetscFunctionReturn(PETSC_SUCCESS);
}
#elif defined(KOKKOS_ENABLE_SYCL)
static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, sycl::queue *q, EigenWorkspace *work)
{
  PetscFunctionBegin;
  /* Use pre-allocated workspace. The d_info buffer is unused on this backend because
     oneMKL reports failures through exceptions rather than a device info array, so
     there is no info check after the solve loop below. */
  PetscScalar *d_work     = work->d_work;
  PetscScalar *d_A_contig = work->d_A_contig;
  PetscScalar *d_W_contig = work->d_W_contig;

  /* Copy T_batch to a contiguous column-major layout for oneMKL */
  Kokkos::parallel_for(
    "ReorganizeForOneMKL", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
      }
    });
  Kokkos::fence();

  /* oneMKL has no batched syevd, so we call oneapi::mkl::lapack::syevd (which computes
     eigenvalues and eigenvectors, and is overloaded on the scalar type) once per matrix */
  for (int i = 0; i < n_batch; i++) {
    PetscScalar *A_ptr = d_A_contig + i * n_size * n_size;
    PetscScalar *W_ptr = d_W_contig + i * n_size;

    try {
      oneapi::mkl::lapack::syevd(*q, oneapi::mkl::job::vec, oneapi::mkl::uplo::upper, n_size, A_ptr, n_size, W_ptr, d_work, work->lwork_device);
      q->wait();
    } catch (sycl::exception const &e) {
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "oneMKL syevd failed for batch %d: %s", i, e.what());
    }
  }

  /* Copy results back from contiguous layout to V_batch */
  Kokkos::parallel_for(
    "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
        Lambda_batch(i, j) = d_W_contig[i * n_size + j];
      }
    });
  Kokkos::fence();
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif
#endif

/*
  BatchedEigenSolve - Compute eigendecomposition for a batch of symmetric matrices

  Input Parameters:
+ T_batch - batch of symmetric matrices (n_batch x n_size x n_size)
. n_batch - number of matrices in the batch
. n_size  - size of each matrix (n_size x n_size)
. device_handle - device-specific solver handle (only for device builds)
- work    - reusable workspace structure

  Output Parameters:
+ Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
- V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

  Notes:
  Dispatcher that calls the appropriate backend (device or host).
*/
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
/* The three device backends differ only in the handle type, so a single alias suffices */
#if defined(KOKKOS_ENABLE_CUDA)
using DeviceSolverHandle = cusolverDnHandle_t;
#elif defined(KOKKOS_ENABLE_HIP)
using DeviceSolverHandle = rocblas_handle;
#elif defined(KOKKOS_ENABLE_SYCL)
using DeviceSolverHandle = sycl::queue *;
#endif
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, DeviceSolverHandle device_handle, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Device(T_batch, Lambda_batch, V_batch, n_batch, n_size, device_handle, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
#else
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Host(T_batch, Lambda_batch, V_batch, n_batch, n_size, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif

/*
  PetscDALETKFSetupLocalization_Kokkos - Prepares device views for the localization matrix Q
*/
PetscErrorCode PetscDALETKFSetupLocalization_Kokkos(PetscDA_LETKF *impl, Mat H)
{
  PetscInt nrows;

  PetscFunctionBegin;
  PetscCheck(impl->Q, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Localization matrix impl->Q has not been created");
  PetscCall(PetscKokkosInitializeCheck());

  /* Get CSR data */
  PetscInt rstart, rend, i, nnz;
  PetscCall(MatGetOwnershipRange(impl->Q, &rstart, &rend));
  nrows = rend - rstart;

  /* Create an IS for the local observations needed by this process: */
  /* we need all unique column indices appearing in the local rows of Q */
  {
    PetscInt     *obs_indices;
    PetscInt      n_obs_local_total = 0;
    PetscInt      max_obs           = nrows * impl->n_obs_vertex;
    PetscInt      count             = 0;
    PetscHMapI    ht;
    PetscHashIter iter;
    PetscBool     missing;

    PetscCall(PetscHMapICreate(&ht));
    PetscCall(PetscMalloc1(max_obs, &obs_indices));

    for (i = 0; i < nrows; i++) {
      const PetscInt    *cols;
      const PetscScalar *vals;
      PetscCall(MatGetRow(impl->Q, rstart + i, &nnz, &cols, &vals));
      for (PetscInt k = 0; k < nnz; k++) {
        PetscCall(PetscHMapIPut(ht, cols[k], &iter, &missing));
        if (missing) {
          obs_indices[count] = cols[k];
          count++;
        }
      }
      PetscCall(MatRestoreRow(impl->Q, rstart + i, &nnz, &cols, &vals));
    }
    n_obs_local_total = count;

    /* Sort indices for consistent ordering */
    PetscCall(PetscSortInt(n_obs_local_total, obs_indices));

    /* Create IS and VecScatter */
    PetscCall(ISCreateGeneral(PETSC_COMM_SELF, n_obs_local_total, obs_indices, PETSC_COPY_VALUES, &impl->obs_is_local));

    /* Create global-to-local map for observations */
    PetscCall(PetscHMapICreate(&impl->obs_g2l));
    for (i = 0; i < n_obs_local_total; i++) {
      PetscCall(PetscHMapIPut(impl->obs_g2l, obs_indices[i], &iter, &missing));
      PetscCall(PetscHMapIIterSet(impl->obs_g2l, iter, i));
    }

    PetscCall(PetscFree(obs_indices));
    PetscCall(PetscHMapIDestroy(&ht));
  }

  /* Create work vectors and scatter context */
  {
    PetscInt n_obs_local_total;
    PetscCall(ISGetLocalSize(impl->obs_is_local, &n_obs_local_total));

    PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->obs_work));
    PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->y_mean_work));
    PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->r_inv_sqrt_work));

    Vec gvec;
    IS  is_to;
    PetscCall(MatCreateVecs(H, NULL, &gvec)); /* template global vector (left vector = rows = observations) */
    PetscCall(ISCreateStride(PETSC_COMM_SELF, n_obs_local_total, 0, 1, &is_to));
    PetscCall(VecScatterCreate(gvec, impl->obs_is_local, impl->obs_work, is_to, &impl->obs_scat));
    PetscCall(VecDestroy(&gvec));
    PetscCall(ISDestroy(&is_to));
  }

  /* Define View types */
  using view_1d_int    = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
  using view_1d_scalar = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;

  /* Allocate device views */
  view_1d_int    *d_Q_i = new view_1d_int("Q_i", nrows + 1);
  view_1d_int    *d_Q_j = new view_1d_int("Q_j", nrows * impl->n_obs_vertex);
  view_1d_scalar *d_Q_a = new view_1d_scalar("Q_a", nrows * impl->n_obs_vertex);

  /* Create host mirrors */
  auto h_Q_i = Kokkos::create_mirror_view(*d_Q_i);
  auto h_Q_j = Kokkos::create_mirror_view(*d_Q_j);
  auto h_Q_a = Kokkos::create_mirror_view(*d_Q_a);

  /* Fill host mirrors with LOCAL indices into obs_work */
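  /* Worked example (hypothetical numbers, for illustration only): if local row i of Q
     references global observation columns {7, 42} with weights {0.9, 0.3}, and
     obs_is_local = [3, 7, 42], then this loop stores h_Q_j entries {1, 2} (the positions
     of 7 and 42 in the sorted IS) and h_Q_a entries {0.9, 0.3} for that row. */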
  h_Q_i(0) = 0;
  for (i = 0; i < nrows; i++) {
    const PetscInt    *cols;
    const PetscScalar *vals;
    PetscCall(MatGetRow(impl->Q, rstart + i, &nnz, &cols, &vals));
    h_Q_i(i + 1) = h_Q_i(i) + nnz;
    for (PetscInt k = 0; k < nnz; k++) {
      PetscInt local_idx;
      PetscCall(ISLocate(impl->obs_is_local, cols[k], &local_idx));
      PetscCheck(local_idx >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Observation index %" PetscInt_FMT " not found in local IS", cols[k]);
      h_Q_j(h_Q_i(i) + k) = local_idx;
      h_Q_a(h_Q_i(i) + k) = vals[k];
    }
    PetscCall(MatRestoreRow(impl->Q, rstart + i, &nnz, &cols, &vals));
  }

  /* Copy to device */
  Kokkos::deep_copy(*d_Q_i, h_Q_i);
  Kokkos::deep_copy(*d_Q_j, h_Q_j);
  Kokkos::deep_copy(*d_Q_a, h_Q_a);

  /* Store in impl */
  PetscCheck(!impl->Q_device_i, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Device views for Q have already been set up");
  impl->Q_device_i = static_cast<void *>(d_Q_i);
  impl->Q_device_j = static_cast<void *>(d_Q_j);
  impl->Q_device_a = static_cast<void *>(d_Q_a);
  PetscFunctionReturn(PETSC_SUCCESS);
}
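
/*
  Expected call sequence (inferred from the roles of these functions; shown as a sketch):

    PetscCall(PetscDALETKFSetupLocalization_Kokkos(impl, H));  // once, after Q is assembled
    PetscCall(PetscDALETKFLocalAnalysis_GPU(da, impl, m, n_vertices, X, y, Z, y_mean, r_inv_sqrt)); // per assimilation step
    PetscCall(PetscDALETKFDestroyLocalization_Kokkos(impl));   // teardown
*/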

PetscErrorCode PetscDALETKFDestroyLocalization_Kokkos(PetscDA_LETKF *impl)
{
  PetscFunctionBegin;
  PetscCall(VecDestroy(&impl->obs_work));
  PetscCall(VecDestroy(&impl->y_mean_work));
  PetscCall(VecDestroy(&impl->r_inv_sqrt_work));
  PetscCall(VecScatterDestroy(&impl->obs_scat));
  PetscCall(MatDestroy(&impl->Z_work));
  PetscCall(PetscHMapIDestroy(&impl->obs_g2l));
  if (impl->Q_device_i) {
    using view_1d_int = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_int *>(impl->Q_device_i);
    impl->Q_device_i = NULL;
  }
  if (impl->Q_device_j) {
    using view_1d_int = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_int *>(impl->Q_device_j);
    impl->Q_device_j = NULL;
  }
  if (impl->Q_device_a) {
    using view_1d_scalar = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_scalar *>(impl->Q_device_a);
    impl->Q_device_a = NULL;
  }

  /* Destroy solver handle and workspace */
  if (impl->eigen_work) {
    EigenWorkspace *work = static_cast<EigenWorkspace *>(impl->eigen_work);

#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
    PetscCallCUDA(cudaFree(work->d_A_contig));
    PetscCallCUDA(cudaFree(work->d_W_contig));
    PetscCallCUDA(cudaFree(work->d_work));
    PetscCallCUDA(cudaFree(work->d_info));
    if (work->syevj_params) cusolverDnDestroySyevjInfo(work->syevj_params);
#elif defined(KOKKOS_ENABLE_HIP)
    PetscCallHIP(hipFree(work->d_A_contig));
    PetscCallHIP(hipFree(work->d_W_contig));
    PetscCallHIP(hipFree(work->d_work));
    PetscCallHIP(hipFree(work->d_info));
#elif defined(KOKKOS_ENABLE_SYCL)
    if (impl->solver_handle) {
      sycl::queue *q = static_cast<sycl::queue *>(impl->solver_handle);
      if (work->d_A_contig) sycl::free(work->d_A_contig, *q);
      if (work->d_W_contig) sycl::free(work->d_W_contig, *q);
      if (work->d_work) sycl::free(work->d_work, *q);
      if (work->d_info) sycl::free(work->d_info, *q);
    }
#endif
#else
#if defined(PETSC_USE_COMPLEX)
    PetscCall(PetscFree4(work->all_v, work->all_lambda, work->all_work, work->all_rwork));
#else
    PetscCall(PetscFree3(work->all_v, work->all_lambda, work->all_work));
#endif
#endif

    delete work;
    impl->eigen_work = NULL;
  }

  if (impl->solver_handle) {
#if defined(KOKKOS_ENABLE_CUDA)
    cusolverDnDestroy(static_cast<cusolverDnHandle_t>(impl->solver_handle));
#elif defined(KOKKOS_ENABLE_HIP)
    rocblas_destroy_handle(static_cast<rocblas_handle>(impl->solver_handle));
#elif defined(KOKKOS_ENABLE_SYCL)
    delete static_cast<sycl::queue *>(impl->solver_handle);
#endif
    impl->solver_handle = NULL;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* ========================================================================== */
/*                     LETKF Local Analysis (Main Function)                   */
/* ========================================================================== */

/*
  PetscDALETKFLocalAnalysis_GPU - Performs local LETKF analysis for all grid points (Kokkos version)

  Input Parameters:
+ da          - the PetscDA context
. impl        - LETKF implementation data
. m           - ensemble size
. n_vertices  - number of grid points
. X           - global anomaly matrix (state_size x m)
. observation - observation vector
. Z_global    - global observation ensemble (obs_size x m)
. y_mean_global - global observation mean
- r_inv_sqrt_global - global R^{-1/2}

  Output:
. da->ensemble - updated with analysis ensemble

  Notes:
  This function performs the local analysis loop for LETKF, processing each grid point
  independently using the local observations selected by the localization matrix Q.
  This is the Kokkos version: the per-vertex analyses are batched and executed on the
  default execution space (the GPU when a device backend is enabled).

  All per-vertex workspace (Z, S, T, V, w, delta, y, y_mean, r_inv_sqrt, ...) lives in
  batched Kokkos Views held in a reusable EigenWorkspace; the analysis at each vertex
  is serial and independent, so no per-vertex PETSc objects are needed.
*/
PetscErrorCode PetscDALETKFLocalAnalysis_GPU(PetscDA da, PetscDA_LETKF *impl, PetscInt m, PetscInt n_vertices, Mat X, Vec observation, Mat Z_global, Vec y_mean_global, Vec r_inv_sqrt_global)
{
  PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
  PetscInt          ndof;
  PetscReal         sqrt_m_minus_1, scale, inflation_inv;

  PetscFunctionBegin;
  ndof           = da->ndof;
  scale          = 1.0 / PetscSqrtReal((PetscReal)(m - 1));
  sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
  inflation_inv  = 1.0 / en->inflation; /* (1/rho) for the T matrix: T = (1/rho)I + S^T*S */
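
  /*
    For reference, the local update computed below is the standard LETKF analysis
    (cf. Hunt, Kostelich & Szunyogh, 2007), written in the notation of this file:

      S     = R^{-1/2} (Z - y_mean 1^T) / sqrt(m-1)   (scaled observation anomalies)
      delta = R^{-1/2} (y - y_mean)                   (scaled innovation)
      T     = (1/rho) I + S^T S                       (m x m, rho = inflation)
      w     = T^{-1} S^T delta                        (mean-update weights)
      E     = mean 1^T + X (w 1^T + sqrt(m-1) T^{-1/2})

    with T^{-1} and T^{-1/2} both evaluated via the eigendecomposition T = V Lambda V^T.
  */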

  /* ===================================================================== */
  /* Step 2.1.1: Create batched workspace for ALL grid points              */
  /* ===================================================================== */
  /*
    NOTE ON PARALLELISM STRATEGY:
    We use Kokkos::RangePolicy over grid points (n_vertices) combined with serial
    per-point kernels. Since the data layout is LayoutLeft (column-major) to match
    PETSc/LAPACK, the index 'i' (grid point) is the fastest-varying index (stride 1).

    RangePolicy maps consecutive threads to consecutive 'i', ensuring perfect memory
    coalescing when accessing arrays like S_batch(i, p, j).

    Using TeamPolicy/TeamVectorRange to parallelize the inner loops (over m or p) would
    assign a team to each 'i', causing threads within a team to access S_batch with
    stride 'n_vertices', which leads to uncoalesced memory access and poor performance
    on GPUs.

    Therefore, RangePolicy plus serial per-point BLAS is the optimal strategy for this
    data layout.
  */
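
  /*
    Illustration of the coalescing argument (a sketch, not code used below): for a
    LayoutLeft View S_batch with extents (chunk, p, m), element (i, k, j) lives at
    linear offset i + chunk*(k + p*j), so for fixed (k, j) consecutive grid points i
    are adjacent in memory. A warp executing the RangePolicy kernel therefore issues
    one coalesced load per (k, j) pair:

      // thread i of a warp, inside a RangePolicy kernel:
      PetscScalar s = S_batch(i, k, j); // neighboring threads touch neighboring addresses
  */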
  using exec_space = Kokkos::DefaultExecutionSpace;
  using view_3d    = Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, exec_space>;
  using view_2d    = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>;

  /* ===================================================================== */
  /* Step 2.1.2a: Pre-extract Q matrix CSR data for device access          */
  /* ===================================================================== */
  using view_1d_int_const    = Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
  using view_1d_scalar_const = Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
  using view_1d_int          = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
  using view_1d_scalar       = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;

  view_1d_int_const    Q_i_view;
  view_1d_int_const    Q_j_view;
  view_1d_scalar_const Q_a_view;

  if (impl->Q_device_i) {
    /* Use pre-allocated device views */
    view_1d_int    *d_Q_i = static_cast<view_1d_int *>(impl->Q_device_i);
    view_1d_int    *d_Q_j = static_cast<view_1d_int *>(impl->Q_device_j);
    view_1d_scalar *d_Q_a = static_cast<view_1d_scalar *>(impl->Q_device_a);

    Q_i_view = view_1d_int_const(d_Q_i->data(), d_Q_i->extent(0));
    Q_j_view = view_1d_int_const(d_Q_j->data(), d_Q_j->extent(0));
    Q_a_view = view_1d_scalar_const(d_Q_a->data(), d_Q_a->extent(0));
  } else {
    /* No device views exist: localization has not been prepared, so we cannot proceed */
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Q matrix must be set up with PetscDALETKFSetupLocalization_Kokkos first");
  }

  /* Get global observation data arrays */
  const PetscScalar *z_global_array, *y_global_array, *y_mean_global_array, *r_inv_sqrt_global_array;
  PetscInt           lda_z_global;
  PetscMemType       z_mem_type, y_mem_type, y_mean_mem_type, r_inv_sqrt_mem_type;

  PetscCall(MatDenseGetArrayReadAndMemType(Z_global, &z_global_array, &z_mem_type));
  PetscCall(VecGetArrayReadAndMemType(observation, &y_global_array, &y_mem_type));
  PetscCall(VecGetArrayReadAndMemType(y_mean_global, &y_mean_global_array, &y_mean_mem_type));
  PetscCall(VecGetArrayReadAndMemType(r_inv_sqrt_global, &r_inv_sqrt_global_array, &r_inv_sqrt_mem_type));
  PetscCall(MatDenseGetLDA(Z_global, &lda_z_global));

  /* Mirror observation data to the device when it lives in host memory */
  Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> z_managed;
  Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  y_managed;
  Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  y_mean_managed;
  Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  r_inv_sqrt_managed;

  const PetscScalar *z_ptr          = z_global_array;
  const PetscScalar *y_ptr          = y_global_array;
  const PetscScalar *y_mean_ptr     = y_mean_global_array;
  const PetscScalar *r_inv_sqrt_ptr = r_inv_sqrt_global_array;

  if (z_mem_type == PETSC_MEMTYPE_HOST) {
    z_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("z_managed", lda_z_global, m);
    Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(z_global_array, lda_z_global, m);
    Kokkos::deep_copy(z_managed, src);
    z_ptr = z_managed.data();
  }
  if (y_mem_type == PETSC_MEMTYPE_HOST) {
    y_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("y_managed", lda_z_global);
    Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(y_global_array, lda_z_global);
    Kokkos::deep_copy(y_managed, src);
    y_ptr = y_managed.data();
  }
  if (y_mean_mem_type == PETSC_MEMTYPE_HOST) {
    y_mean_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("y_mean_managed", lda_z_global);
    Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(y_mean_global_array, lda_z_global);
    Kokkos::deep_copy(y_mean_managed, src);
    y_mean_ptr = y_mean_managed.data();
  }
  if (r_inv_sqrt_mem_type == PETSC_MEMTYPE_HOST) {
    r_inv_sqrt_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("r_inv_sqrt_managed", lda_z_global);
    Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(r_inv_sqrt_global_array, lda_z_global);
    Kokkos::deep_copy(r_inv_sqrt_managed, src);
    r_inv_sqrt_ptr = r_inv_sqrt_managed.data();
  }

  /* Create unmanaged Kokkos views for global observation data */
  using view_2d_unmanaged = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
  using view_1d_unmanaged = Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

  view_2d_unmanaged Z_global_view(z_ptr, lda_z_global, m);
  view_1d_unmanaged y_global_view(y_ptr, lda_z_global);
  view_1d_unmanaged y_mean_global_view(y_mean_ptr, lda_z_global);
  view_1d_unmanaged r_inv_sqrt_global_view(r_inv_sqrt_ptr, lda_z_global);

  /* Get access to the global X matrix, mean vector, and output ensemble */
  const PetscScalar *x_array, *mean_array;
  PetscScalar       *e_array;
  PetscInt           lda_x, lda_e;
  PetscMemType       x_mem_type, mean_mem_type, e_mem_type;

  PetscCall(MatDenseGetArrayReadAndMemType(X, &x_array, &x_mem_type));
  PetscCall(VecGetArrayReadAndMemType(impl->mean, &mean_array, &mean_mem_type));
  PetscCall(MatDenseGetArrayWriteAndMemType(en->ensemble, &e_array, &e_mem_type));
  PetscCall(MatDenseGetLDA(X, &lda_x));
  PetscCall(MatDenseGetLDA(en->ensemble, &lda_e));

  /* Mirror state data to the device when it lives in host memory */
  Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> x_managed;
  Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  mean_managed;
  Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> e_managed;

  const PetscScalar *x_ptr     = x_array;
  const PetscScalar *mean_ptr  = mean_array;
  PetscScalar       *e_ptr     = e_array;
  bool               e_is_copy = false;

  if (x_mem_type == PETSC_MEMTYPE_HOST) {
    x_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("x_managed", lda_x, m);
    Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(x_array, lda_x, m);
    Kokkos::deep_copy(x_managed, src);
    x_ptr = x_managed.data();
  }
  if (mean_mem_type == PETSC_MEMTYPE_HOST) {
    mean_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("mean_managed", lda_x);
    Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(mean_array, lda_x);
    Kokkos::deep_copy(mean_managed, src);
    mean_ptr = mean_managed.data();
  }
  if (e_mem_type == PETSC_MEMTYPE_HOST) {
    e_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("e_managed", lda_e, m);
    Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(e_array, lda_e, m);
    Kokkos::deep_copy(e_managed, src);
    e_ptr     = e_managed.data();
    e_is_copy = true; /* results must be copied back to e_array after the analysis */
  }

  /* Create unmanaged Kokkos views for global data */
  using view_2d_unmanaged_write = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
  view_2d_unmanaged       X_view(const_cast<PetscScalar *>(x_ptr), lda_x, m);
  view_1d_unmanaged       mean_view(mean_ptr, lda_x);
  view_2d_unmanaged_write E_view(e_ptr, lda_e, m);

  /* Determine chunk size to avoid OOM on large grids */
  PetscInt chunk_size;
  if (impl->batch_size > 0) {
    chunk_size = impl->batch_size;
  } else {
    /* Target ~2 GiB of batched workspace. With View reuse, the dominant per-point cost
       is roughly sizeof(PetscScalar) * (m*m for T/V + p*m for Z/S), p = n_obs_vertex */
    PetscInt mem_per_point = sizeof(PetscScalar) * (m * m + impl->n_obs_vertex * m);
    chunk_size = (PetscInt)(2.0 * 1024 * 1024 * 1024 / mem_per_point);
    /* Clamp to a reasonable max to avoid huge allocations even if memory allows */
    if (chunk_size > 32768) chunk_size = 32768;
  }
  if (chunk_size < 1) chunk_size = 1;
  if (chunk_size > n_vertices) chunk_size = n_vertices;
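
  /* Worked example of the heuristic above (hypothetical sizes, double precision):
     m = 64 ensemble members and n_obs_vertex = 128 local observations give
     mem_per_point = 8 * (64*64 + 128*64) = 98304 bytes, so
     chunk_size ~ 2 GiB / 96 KiB ~ 21845 grid points per chunk (below the 32768 clamp). */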

  /* OPTIMIZATION: Create the device solver handle once and reuse it across chunks */
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
  cusolverDnHandle_t device_handle = nullptr;
  cusolverStatus_t   cusolver_status;
  if (impl->solver_handle) {
    device_handle = static_cast<cusolverDnHandle_t>(impl->solver_handle);
  } else {
    cusolver_status = cusolverDnCreate(&device_handle);
    PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDnCreate failed");
    impl->solver_handle = static_cast<void *>(device_handle);
  }
#elif defined(KOKKOS_ENABLE_HIP)
  rocblas_handle device_handle = nullptr;
  if (impl->solver_handle) {
    device_handle = static_cast<rocblas_handle>(impl->solver_handle);
  } else {
    rocblas_status hip_status = rocblas_create_handle(&device_handle);
    PetscCheck(hip_status == rocblas_status_success, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocblas_create_handle failed");
    impl->solver_handle = static_cast<void *>(device_handle);
  }
#elif defined(KOKKOS_ENABLE_SYCL)
  sycl::queue *device_handle = nullptr;
  if (impl->solver_handle) {
    device_handle = static_cast<sycl::queue *>(impl->solver_handle);
  } else {
    device_handle       = new sycl::queue(sycl::gpu_selector_v);
    impl->solver_handle = static_cast<void *>(device_handle);
  }
#endif
#endif

  /* ===================================================================== */
  /* OPTIMIZATION: Hoist allocations outside the chunk loop                */
  /* ===================================================================== */
  /* Allocate Kokkos Views once for the maximum chunk size */
  PetscInt n_obs_vertex_copy = impl->n_obs_vertex;

  EigenWorkspace *eigen_work = static_cast<EigenWorkspace *>(impl->eigen_work);
  if (!eigen_work) {
    eigen_work       = new EigenWorkspace();
    impl->eigen_work = static_cast<void *>(eigen_work);
  }

  /* Check if reallocation is needed */
  if (eigen_work->max_chunk_size < chunk_size || eigen_work->m != m || eigen_work->n_obs_vertex != n_obs_vertex_copy) {
    /* Free the old device workspace if it exists */
#if defined(KOKKOS_ENABLE_CUDA)
    PetscCallCUDA(cudaFree(eigen_work->d_work));
    PetscCallCUDA(cudaFree(eigen_work->d_info));
    PetscCallCUDA(cudaFree(eigen_work->d_A_contig));
    PetscCallCUDA(cudaFree(eigen_work->d_W_contig));
    if (eigen_work->syevj_params) cusolverDnDestroySyevjInfo(eigen_work->syevj_params);
    eigen_work->syevj_params = nullptr;
#elif defined(KOKKOS_ENABLE_HIP)
    PetscCallHIP(hipFree(eigen_work->d_work));
    PetscCallHIP(hipFree(eigen_work->d_info));
    PetscCallHIP(hipFree(eigen_work->d_A_contig));
    PetscCallHIP(hipFree(eigen_work->d_W_contig));
#elif defined(KOKKOS_ENABLE_SYCL)
    if (eigen_work->d_work) sycl::free(eigen_work->d_work, *device_handle);
    if (eigen_work->d_info) sycl::free(eigen_work->d_info, *device_handle);
    if (eigen_work->d_A_contig) sycl::free(eigen_work->d_A_contig, *device_handle);
    if (eigen_work->d_W_contig) sycl::free(eigen_work->d_W_contig, *device_handle);
#endif

#if !defined(KOKKOS_ENABLE_CUDA) && !defined(KOKKOS_ENABLE_HIP) && !defined(KOKKOS_ENABLE_SYCL)
#if defined(PETSC_USE_COMPLEX)
    if (eigen_work->all_v) PetscCall(PetscFree4(eigen_work->all_v, eigen_work->all_lambda, eigen_work->all_work, eigen_work->all_rwork));
#else
    if (eigen_work->all_v) PetscCall(PetscFree3(eigen_work->all_v, eigen_work->all_lambda, eigen_work->all_work));
#endif
#endif

    /* Update dimensions */
    eigen_work->max_chunk_size = chunk_size;
    eigen_work->m              = m;
    eigen_work->n_obs_vertex   = n_obs_vertex_copy;

    /* Allocate Kokkos Views. S_batch deliberately aliases Z_batch (S is formed in place
       over the raw Z values) and V_batch aliases T_batch (the eigenvectors overwrite T),
       which halves the 3D workspace */
    eigen_work->Z_batch               = view_3d("Z_batch", chunk_size, n_obs_vertex_copy, m);
    eigen_work->S_batch               = eigen_work->Z_batch;
    eigen_work->T_batch               = view_3d("T_batch", chunk_size, m, m);
    eigen_work->V_batch               = eigen_work->T_batch;
    eigen_work->Lambda_batch          = view_2d("Lambda_batch", chunk_size, m);
    eigen_work->T_sqrt_batch          = view_3d("T_sqrt_batch", chunk_size, m, m);
    eigen_work->w_batch               = view_2d("w_batch", chunk_size, m);
    eigen_work->delta_batch           = view_2d("delta_batch", chunk_size, n_obs_vertex_copy);
    eigen_work->y_batch               = view_2d("y_batch", chunk_size, n_obs_vertex_copy);
    eigen_work->y_mean_batch          = view_2d("y_mean_batch", chunk_size, n_obs_vertex_copy);
    eigen_work->r_inv_sqrt_batch      = view_2d("r_inv_sqrt_batch", chunk_size, n_obs_vertex_copy);
    eigen_work->temp1_batch           = view_2d("temp1_batch", chunk_size, m);
    eigen_work->temp2_batch           = view_2d("temp2_batch", chunk_size, m);
    eigen_work->inv_sqrt_lambda_batch = view_2d("inv_sqrt_lambda_batch", chunk_size, m);

    /* Allocate solver workspace */
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
    {
      /* Create syevj params */
      cusolver_status = cusolverDnCreateSyevjInfo(&eigen_work->syevj_params);
      PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDnCreateSyevjInfo failed");

      /* Set default params */
      cusolverDnXsyevjSetTolerance(eigen_work->syevj_params, 1e-7);
      cusolverDnXsyevjSetMaxSweeps(eigen_work->syevj_params, 100);
      cusolverDnXsyevjSetSortEig(eigen_work->syevj_params, 1); /* sort eigenvalues */

      /* Query workspace size */
      PetscScalar *d_A = eigen_work->T_batch.data();
      PetscScalar *d_W = eigen_work->Lambda_batch.data();
      int          lwork;
#if defined(PETSC_USE_REAL_SINGLE)
      cusolver_status = cusolverDnSsyevjBatched_bufferSize(device_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, d_A, m, d_W, &lwork, eigen_work->syevj_params, chunk_size);
#else
      cusolver_status = cusolverDnDsyevjBatched_bufferSize(device_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, d_A, m, d_W, &lwork, eigen_work->syevj_params, chunk_size);
#endif
      PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDn*syevjBatched_bufferSize failed");
      eigen_work->lwork_device = lwork;

      /* Allocate workspace */
      PetscCallCUDA(cudaMalloc(&eigen_work->d_work, sizeof(PetscScalar) * lwork));
      PetscCallCUDA(cudaMalloc(&eigen_work->d_info, sizeof(int) * chunk_size));
      PetscCallCUDA(cudaMalloc(&eigen_work->d_A_contig, sizeof(PetscScalar) * chunk_size * m * m));
      PetscCallCUDA(cudaMalloc(&eigen_work->d_W_contig, sizeof(PetscScalar) * chunk_size * m));
    }
#elif defined(KOKKOS_ENABLE_HIP)
    {
      /* rocsolver_?syevd does not support a workspace-size query via lwork = -1,
         so we use a safe upper bound based on the LAPACK dsyevd requirements */
#if defined(PETSC_USE_COMPLEX)
      int lwork = 0; /* complex not supported on this backend */
#else
      int lwork = 1 + 6 * m + 2 * m * m;
#endif
      eigen_work->lwork_device = lwork;

      /* Allocate workspace */
      if (lwork > 0) {
        PetscCallHIP(hipMalloc(&eigen_work->d_work, sizeof(PetscScalar) * lwork));
        PetscCallHIP(hipMalloc(&eigen_work->d_info, sizeof(int) * chunk_size));
        PetscCallHIP(hipMalloc(&eigen_work->d_A_contig, sizeof(PetscScalar) * chunk_size * m * m));
        PetscCallHIP(hipMalloc(&eigen_work->d_W_contig, sizeof(PetscScalar) * chunk_size * m));
      }
    }
#elif defined(KOKKOS_ENABLE_SYCL)
    {
      /* Workspace size for oneapi::mkl::lapack::syevd, using the classical LAPACK bounds:
         lwork >= 1 + 6*n + 2*n*n for real, lwork >= 2*n + n*n for complex */
      int lwork;
#if defined(PETSC_USE_COMPLEX)
      lwork = 2 * m + m * m;
#else
      lwork = 1 + 6 * m + 2 * m * m;
#endif
      eigen_work->lwork_device = lwork;

      /* Allocate workspace using SYCL malloc_device */
      eigen_work->d_work     = sycl::malloc_device<PetscScalar>(lwork, *device_handle);
      eigen_work->d_info     = sycl::malloc_device<int>(chunk_size, *device_handle);
      eigen_work->d_A_contig = sycl::malloc_device<PetscScalar>(chunk_size * m * m, *device_handle);
      eigen_work->d_W_contig = sycl::malloc_device<PetscScalar>(chunk_size * m, *device_handle);
      PetscCheck(eigen_work->d_work && eigen_work->d_info && eigen_work->d_A_contig && eigen_work->d_W_contig, PETSC_COMM_SELF, PETSC_ERR_MEM, "SYCL memory allocation failed");
    }
#endif
#else
    {
      PetscBLASInt n_blas;
      PetscCall(PetscBLASIntCast(m, &n_blas));
      eigen_work->n_blas = n_blas;

      /* Query the optimal LAPACK workspace size (lwork = -1) */
      PetscBLASInt lwork_query = -1;
      PetscScalar  work_query;
      PetscBLASInt info;
#if defined(PETSC_USE_COMPLEX)
      PetscReal rwork_query;
      LAPACKsyev_("V", "U", &n_blas, &work_query, &n_blas, &rwork_query, &work_query, &lwork_query, &rwork_query, &info);
#else
      LAPACKsyev_("V", "U", &n_blas, &work_query, &n_blas, &work_query, &work_query, &lwork_query, &info);
#endif
      PetscCheck(info == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "LAPACK workspace query failed");
      eigen_work->lwork = (PetscBLASInt)PetscRealPart(work_query);

      /* Allocate workspace */
#if defined(PETSC_USE_COMPLEX)
      PetscCall(PetscMalloc4(chunk_size * m * m, &eigen_work->all_v, chunk_size * m, &eigen_work->all_lambda, chunk_size * eigen_work->lwork, &eigen_work->all_work, chunk_size * (3 * m - 2), &eigen_work->all_rwork));
#else
      PetscCall(PetscMalloc3(chunk_size * m * m, &eigen_work->all_v, chunk_size * m, &eigen_work->all_lambda, chunk_size * eigen_work->lwork, &eigen_work->all_work));
#endif
    }
#endif
  }

  /* Create aliases for current function use */
  view_3d Z_batch_alloc               = eigen_work->Z_batch;
  view_3d S_batch_alloc               = eigen_work->S_batch;
  view_3d T_batch_alloc               = eigen_work->T_batch;
  view_3d V_batch_alloc               = eigen_work->V_batch;
  view_2d Lambda_batch_alloc          = eigen_work->Lambda_batch;
  view_3d T_sqrt_batch_alloc          = eigen_work->T_sqrt_batch;
  view_2d w_batch_alloc               = eigen_work->w_batch;
  view_2d delta_batch_alloc           = eigen_work->delta_batch;
  view_2d y_batch_alloc               = eigen_work->y_batch;
  view_2d y_mean_batch_alloc          = eigen_work->y_mean_batch;
  view_2d r_inv_sqrt_batch_alloc      = eigen_work->r_inv_sqrt_batch;
  view_2d temp1_batch_alloc           = eigen_work->temp1_batch;
  view_2d temp2_batch_alloc           = eigen_work->temp2_batch;
  view_2d inv_sqrt_lambda_batch_alloc = eigen_work->inv_sqrt_lambda_batch;

  /* Loop over chunks */
  for (PetscInt chunk_start = 0; chunk_start < n_vertices; chunk_start += chunk_size) {
    PetscInt chunk_end       = (chunk_start + chunk_size > n_vertices) ? n_vertices : chunk_start + chunk_size;
    PetscInt n_batch_current = chunk_end - chunk_start;

    /* Create subviews for the current batch size */
    auto Z_batch               = Kokkos::subview(Z_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
    auto S_batch               = Kokkos::subview(S_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
    auto T_batch               = Kokkos::subview(T_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
    auto V_batch               = Kokkos::subview(V_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
    auto Lambda_batch          = Kokkos::subview(Lambda_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto T_sqrt_batch          = Kokkos::subview(T_sqrt_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
    auto w_batch               = Kokkos::subview(w_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto delta_batch           = Kokkos::subview(delta_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto y_batch               = Kokkos::subview(y_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto y_mean_batch          = Kokkos::subview(y_mean_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto r_inv_sqrt_batch      = Kokkos::subview(r_inv_sqrt_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto temp1_batch           = Kokkos::subview(temp1_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto temp2_batch           = Kokkos::subview(temp2_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto inv_sqrt_lambda_batch = Kokkos::subview(inv_sqrt_lambda_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());

    /* ===================================================================== */
    /* Step 2.1.2: Fused observation extraction and S/delta computation      */
    /* ===================================================================== */
    /* Extract local observations and immediately compute S and delta. */
    /* This fusion eliminates one kernel launch and improves cache locality. */
    Kokkos::parallel_for(
      "ExtractAndComputeSAndDelta", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i_local) {
        PetscInt i_global = chunk_start + i_local;
        /* Get the Q row for this grid point using the CSR format */
        PetscInt row_start = Q_i_view(i_global);
        PetscInt row_end   = Q_i_view(i_global + 1);
        PetscInt ncols     = row_end - row_start;

        /* Extract observations and compute S/delta for this grid point */
        for (PetscInt k = 0; k < ncols; k++) {
          PetscInt    obs_idx = Q_j_view(row_start + k);
          PetscScalar weight  = Q_a_view(row_start + k);

          /* Extract observation vectors */
          PetscScalar y_val      = y_global_view(obs_idx);
          PetscScalar y_mean_val = y_mean_global_view(obs_idx);
          PetscScalar r_inv_sqrt = r_inv_sqrt_global_view(obs_idx) * Kokkos::sqrt(PetscRealPart(weight));

          /* Store for later use if needed */
          y_batch(i_local, k)          = y_val;
          y_mean_batch(i_local, k)     = y_mean_val;
          r_inv_sqrt_batch(i_local, k) = r_inv_sqrt;

          /* Compute delta immediately: delta = R^{-1/2}(y - y_mean) */
          delta_batch(i_local, k) = (y_val - y_mean_val) * r_inv_sqrt;

          /* Compute S row: S = R^{-1/2}(Z - y_mean * 1^T)/sqrt(m-1). Since S_batch
             aliases Z_batch, writing S here overwrites the raw Z values in place,
             so we do not store Z separately. */
          PetscScalar scale_factor = scale * r_inv_sqrt;
          for (int j = 0; j < m; j++) {
            PetscScalar z_val      = Z_global_view(obs_idx, j);
            S_batch(i_local, k, j) = (z_val - y_mean_val) * scale_factor;
          }
        }
      });
    Kokkos::fence();

    /* DEBUG: Check S for NaNs */
    if (PetscDefined(USE_DEBUG)) {
      PetscInt nan_count = 0;
      Kokkos::parallel_reduce(
        "CheckS", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
        KOKKOS_LAMBDA(const int i, PetscInt &l_count) {
          for (int j = 0; j < n_obs_vertex_copy; j++) {
            for (int k = 0; k < m; k++) {
              if (S_batch(i, j, k) != S_batch(i, j, k)) l_count++;
            }
          }
        },
        nan_count);
      PetscCheck(nan_count == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " NaNs in S_batch at chunk_start %" PetscInt_FMT, nan_count, chunk_start);
    }

    /* ===================================================================== */
    /* Step 2.1.4: Optimized T matrix formation (T = (1/rho)I + S^T * S)     */
    /* ===================================================================== */
    /* Compute T_i = (1/rho)I + S_i^T * S_i for the current chunk. */
    /* Exploit symmetry: only compute the upper triangle, then copy it to the lower, */
    /* which cuts the flop count by ~50%. */
    Kokkos::parallel_for(
      "ComputeAllTMatrices", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i) {
        auto S_i = Kokkos::subview(S_batch, i, Kokkos::ALL(), Kokkos::ALL());
        auto T_i = Kokkos::subview(T_batch, i, Kokkos::ALL(), Kokkos::ALL());

        /* Compute the upper triangle of T_i = (1/rho)I + S_i^T * S_i */
        /* T_i(j,k) = (1/rho)*delta_jk + sum_p S_i(p,j) * S_i(p,k) for j <= k */
        for (int j = 0; j < m; j++) {
          for (int k = j; k < m; k++) {
            PetscScalar sum = (j == k) ? inflation_inv : 0.0;
            for (int p = 0; p < n_obs_vertex_copy; p++) sum += S_i(p, j) * S_i(p, k);
            T_i(j, k) = sum;
          }
        }

        /* Copy the upper triangle to the lower triangle (T is symmetric) */
        for (int j = 0; j < m; j++) {
          for (int k = 0; k < j; k++) T_i(j, k) = T_i(k, j);
        }
      });
    Kokkos::fence();

    /* DEBUG: Check T for NaNs */
    if (PetscDefined(USE_DEBUG)) {
      PetscInt nan_count = 0;
      Kokkos::parallel_reduce(
        "CheckT", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
        KOKKOS_LAMBDA(const int i, PetscInt &l_count) {
          for (int j = 0; j < m; j++) {
            for (int k = 0; k < m; k++) {
              if (T_batch(i, j, k) != T_batch(i, j, k)) l_count++;
            }
          }
        },
        nan_count);
      PetscCheck(nan_count == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " NaNs in T_batch at chunk_start %" PetscInt_FMT, nan_count, chunk_start);
    }

    /* ===================================================================== */
    /* Step 3.1.1: Batched eigendecomposition for the current chunk          */
    /* ===================================================================== */
    /* Compute T_i = V_i * Lambda_i * V_i^T for the current chunk */
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
    PetscCall(BatchedEigenSolve(T_batch, Lambda_batch, V_batch, n_batch_current, m, device_handle, eigen_work));
#else
    PetscCall(BatchedEigenSolve(T_batch, Lambda_batch, V_batch, n_batch_current, m, eigen_work));
#endif

    /* DEBUG: Check Lambda for NaNs or negative values */
    if (PetscDefined(USE_DEBUG)) {
      PetscInt bad_lambda = 0;
      Kokkos::parallel_reduce(
        "CheckLambda", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
        KOKKOS_LAMBDA(const int i, PetscInt &l_count) {
          for (int k = 0; k < m; k++) {
            if (Lambda_batch(i, k) != Lambda_batch(i, k) || PetscRealPart(Lambda_batch(i, k)) < -1e-8) l_count++;
          }
        },
        bad_lambda);
      PetscCheck(bad_lambda == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " bad eigenvalues (NaN or negative) at chunk_start %" PetscInt_FMT, bad_lambda, chunk_start);
    }
1201: /* ===================================================================== */
1202: /* Step 3.1.2: Precompute w and inv_sqrt_lambda for ensemble update */
1203: /* ===================================================================== */
1204: /* Compute w_i = T_i^{-1} * (S_i^T * delta_i) using eigendecomposition */
1205: /* Precompute 1/sqrt(Lambda) for use in ensemble update */
1206: Kokkos::parallel_for(
1207: "ComputeWeightsAndInvSqrtLambda", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i) {
1208: auto S_i = Kokkos::subview(S_batch, i, Kokkos::ALL(), Kokkos::ALL());
1209: auto V_i = Kokkos::subview(V_batch, i, Kokkos::ALL(), Kokkos::ALL());
1210: auto Lambda_i = Kokkos::subview(Lambda_batch, i, Kokkos::ALL());
1211: auto delta_i = Kokkos::subview(delta_batch, i, Kokkos::ALL());
1212: auto w_i = Kokkos::subview(w_batch, i, Kokkos::ALL());
1213: auto inv_sqrt_lambda_i = Kokkos::subview(inv_sqrt_lambda_batch, i, Kokkos::ALL());
1214: auto temp1 = Kokkos::subview(temp1_batch, i, Kokkos::ALL());
1215: auto temp2 = Kokkos::subview(temp2_batch, i, Kokkos::ALL());
1217: /* 1. Compute w_i = V * Lambda^{-1} * V^T * S^T * delta_i */
1218: /* Step 1a: temp1 = S^T * delta using KokkosBlas::SerialGemv for better vectorization */
1219: KokkosBlas::SerialGemv<KokkosBlas::Trans::Transpose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, S_i, delta_i, 0.0, temp1);
1221: /* Step 1b: temp2 = V^T * temp1 using KokkosBlas::SerialGemv for better vectorization */
1222: KokkosBlas::SerialGemv<KokkosBlas::Trans::Transpose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, V_i, temp1, 0.0, temp2);
1224: /* Step 1c: temp2 = temp2 / Lambda (the 1e-14 shift guards against a zero eigenvalue) */
1225: for (int j = 0; j < m; j++) temp2(j) /= (Lambda_i(j) + 1.0e-14);
1227: /* Step 1d: w = V * temp2 using KokkosBlas::SerialGemv for better vectorization */
1228: KokkosBlas::SerialGemv<KokkosBlas::Trans::NoTranspose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, V_i, temp2, 0.0, w_i);
1230: /* 2. Precompute 1/sqrt(Lambda) (same defensive 1e-14 shift) for the ensemble update */
1231: for (int p = 0; p < m; p++) inv_sqrt_lambda_i(p) = 1.0 / Kokkos::sqrt(PetscRealPart(Lambda_i(p)) + 1.0e-14);
1232: });
1233: Kokkos::fence();
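/* Worked form of the kernel above: with T = V * diag(Lambda) * V^T and V
   orthogonal, T^{-1} = V * diag(1/Lambda) * V^T, so

     w = T^{-1} * S^T * delta = V * ( (V^T * (S^T * delta)) ./ Lambda ),

   evaluated right to left as three gemv's plus one elementwise scaling,
   O(m*n_obs + m^2) per vertex. The eigenpairs are reused for T^{-1/2} in
   Step 3.1.3, so a single factorization serves both solves. */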
1235: /* ===================================================================== */
1236: /* Step 3.1.3: Fused G computation and ensemble update */
1237: /* ===================================================================== */
1238: /* Compute E[i,:] = mean[i] + X[i,:] * G_i on-the-fly */
1239: /* G_i is computed column-by-column and immediately applied */
1240: /* This eliminates the need to store G_batch, saving m*m*n_batch scalars of storage */
1241: Kokkos::parallel_for(
1242: "FusedGComputeAndEnsembleUpdate", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i_local) {
1243: PetscInt i_global = chunk_start + i_local;
1245: auto X_i = Kokkos::subview(X_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof), Kokkos::ALL());
1246: auto E_i = Kokkos::subview(E_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof), Kokkos::ALL());
1247: auto mean_i = Kokkos::subview(mean_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof));
1249: auto V_i = Kokkos::subview(V_batch, i_local, Kokkos::ALL(), Kokkos::ALL());
1250: auto w_i = Kokkos::subview(w_batch, i_local, Kokkos::ALL());
1251: auto inv_sqrt_lambda_i = Kokkos::subview(inv_sqrt_lambda_batch, i_local, Kokkos::ALL());
1252: auto T_sqrt_i = Kokkos::subview(T_sqrt_batch, i_local, Kokkos::ALL(), Kokkos::ALL());
1254: /* Initialize E_i with mean */
1255: for (int row = 0; row < ndof; row++) {
1256: PetscScalar m_val = mean_i(row);
1257: for (int col = 0; col < m; col++) E_i(row, col) = m_val;
1258: }
1260: /* Compute T_sqrt = V * diag(1/sqrt(Lambda)) * V^T, i.e. the inverse square root T^{-1/2}, despite the name */
1261: /* Optimized: Exploit symmetry - only compute upper triangle, then copy to lower */
1262: /* T_sqrt(j,k) = sum_p V(j,p) * V(k,p) / sqrt(Lambda(p)) for j <= k */
1263: for (int j = 0; j < m; j++) {
1264: for (int k = j; k < m; k++) {
1265: PetscScalar sum = 0.0;
1266: for (int p = 0; p < m; p++) sum += V_i(j, p) * V_i(k, p) * inv_sqrt_lambda_i(p);
1267: T_sqrt_i(j, k) = sum;
1268: }
1269: }
1270: /* Copy upper triangle to lower triangle (T_sqrt is symmetric) */
1271: for (int j = 0; j < m; j++) {
1272: for (int k = 0; k < j; k++) T_sqrt_i(j, k) = T_sqrt_i(k, j);
1273: }
1275: /* Compute E_i += X_i * G_i column-by-column */
1276: /* G_i(:,k) = w_i + sqrt(m-1) * T_sqrt_i(:,k) */
1277: for (int k = 0; k < m; k++) {
1278: /* Compute column k of G on-the-fly */
1279: for (int row = 0; row < ndof; row++) {
1280: PetscScalar sum = 0.0;
1281: for (int j = 0; j < m; j++) {
1282: /* G_i(j,k) = w_i(j) + sqrt(m-1) * T_sqrt_i(j,k) */
1283: PetscScalar G_jk = w_i(j) + sqrt_m_minus_1 * T_sqrt_i(j, k);
1284: sum += X_i(row, j) * G_jk;
1285: }
1286: E_i(row, k) += sum;
1287: }
1288: }
1289: });
1290: Kokkos::fence();
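/* The fused kernel above implements the standard LETKF transform (Hunt et
   al. 2007): with the inverse square root held in T_sqrt_i,

     G = w * 1^T + sqrt(m-1) * T^{-1/2},   E_i = mean_i * 1^T + X_i * G,

   so column k of the analysis ensemble is the mean update (through w) plus
   the k-th rescaled perturbation (through T^{-1/2}). Recomputing G_jk inside
   the inner loop costs roughly 2*ndof*m*m extra flops per vertex (counted in
   the log below) in exchange for never materializing G_batch. */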
1291: }
1293: /* No per-call workspace cleanup needed: */
1294: /* the workspace persists in impl->eigen_work and impl->solver_handle */
1295: /* and is destroyed in PetscDALETKFDestroyLocalization_Kokkos */
1297: /* Copy back the updated ensemble when e_managed is a separate copy of the raw array (e_is_copy is presumably set where e_managed was created) */
1298: if (e_is_copy) {
1299: Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> dst(e_array, lda_e, m);
1300: Kokkos::deep_copy(dst, e_managed);
1301: }
1303: /* Restore arrays */
1304: PetscCall(MatDenseRestoreArrayWriteAndMemType(en->ensemble, &e_array));
1305: PetscCall(VecRestoreArrayReadAndMemType(impl->mean, &mean_array));
1306: PetscCall(MatDenseRestoreArrayReadAndMemType(X, &x_array));
1308: /* Restore global observation arrays */
1309: PetscCall(VecRestoreArrayReadAndMemType(r_inv_sqrt_global, &r_inv_sqrt_global_array));
1310: PetscCall(VecRestoreArrayReadAndMemType(y_mean_global, &y_mean_global_array));
1311: PetscCall(VecRestoreArrayReadAndMemType(observation, &y_global_array));
1312: PetscCall(MatDenseRestoreArrayReadAndMemType(Z_global, &z_global_array));
1314: /* Ensemble has been updated in batched form above */
1315: PetscCall(MatAssemblyBegin(en->ensemble, MAT_FINAL_ASSEMBLY));
1316: PetscCall(MatAssemblyEnd(en->ensemble, MAT_FINAL_ASSEMBLY));
1318: {
1319: MatInfo info;
1320: PetscReal flops = 0.0;
1321: PetscReal n_obs_total;
1323: if (impl->Q) {
1324: PetscCall(MatGetInfo(impl->Q, MAT_LOCAL, &info));
1325: n_obs_total = info.nz_used;
1326: } else {
1327: n_obs_total = 0.0;
1328: }
1330: /* Step 2.1.2: Fused observation extraction and S/Delta computation */
1331: flops += n_obs_total * (2.0 + 2.0 * m);
1333: /* Step 2.1.4: Optimized T matrix formation */
1334: flops += (PetscReal)n_vertices * m * (m + 1) * impl->n_obs_vertex;
1336: /* Step 3.1.2: Precompute w and inv_sqrt_lambda */
1337: flops += (PetscReal)n_vertices * (2.0 * m * impl->n_obs_vertex + 4.0 * m * m + 3.0 * m);
1339: /* Step 3.1.3: Fused G computation and ensemble update */
1340: /* T_sqrt: 1.5*m^3 + 1.5*m^2 */
1341: flops += (PetscReal)n_vertices * (1.5 * m * m * m + 1.5 * m * m);
1342: /* E update: ndof * m * (4*m + 1) */
1343: /* Note: the G_jk computation (2 flops) sits in the innermost loop, contributing 2*ndof*m*m, */
1344: /* and the matrix product X*G (2 flops per term) contributes another 2*ndof*m*m */
1345: flops += (PetscReal)n_vertices * ndof * m * (4.0 * m + 1.0);
1347: PetscCall(PetscLogGpuFlops(flops));
1348: }
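/* How the Step 2.1.4 count above is derived (the others follow the same
   pattern): the upper triangle has m*(m+1)/2 entries, each costing one
   multiply and one add per observation, giving m*(m+1)*n_obs_vertex flops
   per vertex. Note the Step 3.1.1 eigensolve has no entry in this tally. */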
1349: PetscFunctionReturn(PETSC_SUCCESS);
1350: }