Actual source code: fsolvebaij.F90
1: !
2: !
3: ! Fortran kernel for sparse triangular solve in the BAIJ matrix format
4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
6: !
7: #include <petsc/finclude/petscsys.h>
8: !
pure subroutine FortranSolveBAIJ4Unroll(n, x, ai, aj, adiag, a, b)
!
!  Triangular solves with a factored matrix stored in 4x4 blocks.
!  A forward sweep over the blocks left of the diagonal is followed by a
!  backward sweep over the blocks right of the diagonal and the diagonal
!  block itself.  Every 4x4 block-times-vector product is written out by
!  hand (no inner loops over the block entries).
!
!  Arguments (all arrays are indexed from 0):
!    n     - number of block rows; the routine returns at once when n < 1
!    x     - receives the solution, 4 entries per block row
!    ai    - ai(i) is the position of the first stored block of row i and
!            ai(i + 1) - 1 the position of the last one
!    aj    - block column index of each stored block
!    adiag - adiag(i) is the position of the diagonal block of row i
!    a     - block values, 16 per block; entry (r, c) of the block at
!            position p is a(16*p + r + 4*c)
!    b     - right-hand side, 4 entries per block row
!
!  The diagonal block is applied by multiplication, never by division, so
!  it is presumably stored already inverted - confirm with the factorization.
!
  use, intrinsic :: ISO_C_binding
  implicit none(type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

  PetscInt :: i, j, jstart, jend
  PetscInt :: idx, ax, jdx
  PetscScalar :: s(0:3)

! NOTE(review): the arrays are declared 0-based, so index 1 below is the
! second element of each array - confirm that this is the intended address.
  PETSC_AssertAlignx(16, a(1))
  PETSC_AssertAlignx(16, x(1))
  PETSC_AssertAlignx(16, b(1))
  PETSC_AssertAlignx(16, ai(1))
  PETSC_AssertAlignx(16, aj(1))
  PETSC_AssertAlignx(16, adiag(1))

! Nothing to solve: do not touch x or b when there are no block rows.
  if (n < 1) return

!
!  Forward solve: block row 0 has no blocks left of the diagonal
!
  x(0:3) = b(0:3)
  do i = 1, n - 1
    jstart = ai(i)
    jend = adiag(i) - 1
    ax = 16*jstart
    idx = 4*i
    s(0:3) = b(idx:idx + 3)
    do j = jstart, jend
      jdx = 4*aj(j)

      s(0) = s(0) - (a(ax + 0)*x(jdx + 0) + a(ax + 4)*x(jdx + 1) + a(ax + 8)*x(jdx + 2) + a(ax + 12)*x(jdx + 3))
      s(1) = s(1) - (a(ax + 1)*x(jdx + 0) + a(ax + 5)*x(jdx + 1) + a(ax + 9)*x(jdx + 2) + a(ax + 13)*x(jdx + 3))
      s(2) = s(2) - (a(ax + 2)*x(jdx + 0) + a(ax + 6)*x(jdx + 1) + a(ax + 10)*x(jdx + 2) + a(ax + 14)*x(jdx + 3))
      s(3) = s(3) - (a(ax + 3)*x(jdx + 0) + a(ax + 7)*x(jdx + 1) + a(ax + 11)*x(jdx + 2) + a(ax + 15)*x(jdx + 3))
      ax = ax + 16
    end do
    x(idx:idx + 3) = s(0:3)
  end do

!
!  Backward solve the upper triangular
!
  do i = n - 1, 0, -1
    jstart = adiag(i) + 1
    jend = ai(i + 1) - 1
    ax = 16*jstart
    idx = 4*i
    s(0:3) = x(idx:idx + 3)
    do j = jstart, jend
      jdx = 4*aj(j)
      s(0) = s(0) - (a(ax + 0)*x(jdx + 0) + a(ax + 4)*x(jdx + 1) + a(ax + 8)*x(jdx + 2) + a(ax + 12)*x(jdx + 3))
      s(1) = s(1) - (a(ax + 1)*x(jdx + 0) + a(ax + 5)*x(jdx + 1) + a(ax + 9)*x(jdx + 2) + a(ax + 13)*x(jdx + 3))
      s(2) = s(2) - (a(ax + 2)*x(jdx + 0) + a(ax + 6)*x(jdx + 1) + a(ax + 10)*x(jdx + 2) + a(ax + 14)*x(jdx + 3))
      s(3) = s(3) - (a(ax + 3)*x(jdx + 0) + a(ax + 7)*x(jdx + 1) + a(ax + 11)*x(jdx + 2) + a(ax + 15)*x(jdx + 3))
      ax = ax + 16
    end do
!   Apply the stored diagonal block to the accumulated right-hand side
    ax = 16*adiag(i)
    x(idx + 0) = a(ax + 0)*s(0) + a(ax + 4)*s(1) + a(ax + 8)*s(2) + a(ax + 12)*s(3)
    x(idx + 1) = a(ax + 1)*s(0) + a(ax + 5)*s(1) + a(ax + 9)*s(2) + a(ax + 13)*s(3)
    x(idx + 2) = a(ax + 2)*s(0) + a(ax + 6)*s(1) + a(ax + 10)*s(2) + a(ax + 14)*s(3)
    x(idx + 3) = a(ax + 3)*s(0) + a(ax + 7)*s(1) + a(ax + 11)*s(2) + a(ax + 15)*s(3)
  end do
end subroutine FortranSolveBAIJ4Unroll
78: ! version that does not call BLAS 2 operation for each row block
79: !
pure subroutine FortranSolveBAIJ4(n, x, ai, aj, adiag, a, b, w)
!
!  Triangular solves with a factored matrix stored in 4x4 blocks.
!  For each block row the needed entries of x are first gathered into the
!  work array w, then the whole row of blocks is applied to w as one
!  4-by-m matrix-vector product written with plain loops.
!
!  Arguments (all arrays are indexed from 0):
!    n     - number of block rows; the routine returns at once when n < 1
!    x     - receives the solution, 4 entries per block row
!    ai    - ai(i) is the position of the first stored block of row i and
!            ai(i + 1) - 1 the position of the last one
!    aj    - block column index of each stored block
!    adiag - adiag(i) is the position of the diagonal block of row i
!    a     - block values, 16 per block; entry (r, c) of the block at
!            position p is a(16*p + r + 4*c)
!    b     - right-hand side, 4 entries per block row
!    w     - scratch space overwritten with gathered entries of x
!
!  Execution is halted with error stop when one side of a block row holds
!  more than max_row_blocks blocks.  Presumably w is sized by the caller
!  for 4*max_row_blocks entries - confirm against the caller.
!
!  The diagonal block is applied by multiplication, never by division, so
!  it is presumably stored already inverted - confirm with the factorization.
!
  use, intrinsic :: ISO_C_binding
  implicit none(type, external)
  MatScalar, intent(in) :: a(0:*)
  PetscScalar, intent(inout) :: x(0:*), w(0:*)
  PetscScalar, intent(in) :: b(0:*)
  PetscInt, intent(in) :: n
  PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

! Largest number of blocks gathered into w for one sweep of one block row
  PetscInt, parameter :: max_row_blocks = 500

  PetscInt :: jj, i, j
  PetscInt :: jstart, jend, idx, ax, jdx, kdx, nn, col
  PetscScalar :: s(0:3)

! NOTE(review): the arrays are declared 0-based, so index 1 below is the
! second element of each array - confirm that this is the intended address.
  PETSC_AssertAlignx(16, a(1))
  PETSC_AssertAlignx(16, w(1))
  PETSC_AssertAlignx(16, x(1))
  PETSC_AssertAlignx(16, b(1))
  PETSC_AssertAlignx(16, ai(1))
  PETSC_AssertAlignx(16, aj(1))
  PETSC_AssertAlignx(16, adiag(1))

! Nothing to solve: do not touch x, b or w when there are no block rows.
  if (n < 1) return

!
!  Forward solve: block row 0 has no blocks left of the diagonal
!
  x(0:3) = b(0:3)
  do i = 1, n - 1
    jstart = ai(i)
    jend = adiag(i) - 1

    if (jend - jstart >= max_row_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

!
!   Pack required part of vector into work array
!
    kdx = 0
    do j = jstart, jend
      jdx = 4*aj(j)
      w(kdx:kdx + 3) = x(jdx:jdx + 3)
      kdx = kdx + 4
    end do

    ax = 16*jstart
    idx = 4*i
    s(0:3) = b(idx:idx + 3)
!
!   s = s - a(ax:)*w, one contiguous 4-entry column of a per entry of w
!
    nn = 4*(jend - jstart + 1) - 1
    do jj = 0, nn
      col = ax + 4*jj
      s(0:3) = s(0:3) - a(col:col + 3)*w(jj)
    end do

    x(idx:idx + 3) = s(0:3)
  end do

!
!  Backward solve the upper triangular
!
  do i = n - 1, 0, -1
    jstart = adiag(i) + 1
    jend = ai(i + 1) - 1

    if (jend - jstart >= max_row_blocks) error stop 'Overflowing vector FortranSolveBAIJ4()'

    ax = 16*jstart
    idx = 4*i
    s(0:3) = x(idx:idx + 3)
!
!   Pack each chunk of vector needed
!
    kdx = 0
    do j = jstart, jend
      jdx = 4*aj(j)
      w(kdx:kdx + 3) = x(jdx:jdx + 3)
      kdx = kdx + 4
    end do

!
!   s = s - a(ax:)*w, one contiguous 4-entry column of a per entry of w
!
    nn = 4*(jend - jstart + 1) - 1
    do jj = 0, nn
      col = ax + 4*jj
      s(0:3) = s(0:3) - a(col:col + 3)*w(jj)
    end do

!   Apply the stored diagonal block to the accumulated right-hand side
    ax = 16*adiag(i)
    x(idx) = a(ax + 0)*s(0) + a(ax + 4)*s(1) + a(ax + 8)*s(2) + a(ax + 12)*s(3)
    x(idx + 1) = a(ax + 1)*s(0) + a(ax + 5)*s(1) + a(ax + 9)*s(2) + a(ax + 13)*s(3)
    x(idx + 2) = a(ax + 2)*s(0) + a(ax + 6)*s(1) + a(ax + 10)*s(2) + a(ax + 14)*s(3)
    x(idx + 3) = a(ax + 3)*s(0) + a(ax + 7)*s(1) + a(ax + 11)*s(2) + a(ax + 15)*s(3)
  end do
end subroutine FortranSolveBAIJ4