Actual source code: fsolvebaij.F90

  1: !
  2: !
  3: !    Fortran kernel for sparse triangular solve in the BAIJ matrix format
  4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
  5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
  6: !
  7: #include <petsc/finclude/petscsys.h>
  8: !

 10: pure subroutine FortranSolveBAIJ4Unroll(n, x, ai, aj, adiag, a, b)
 11:   use, intrinsic :: ISO_C_binding
 12:   implicit none(type, external)
 13:   MatScalar, intent(in) :: a(0:*)
 14:   PetscScalar, intent(inout) :: x(0:*)
 15:   PetscScalar, intent(in) :: b(0:*)
 16:   PetscInt, intent(in) :: n
 17:   PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

 19:   PetscInt :: i, j, jstart, jend
 20:   PetscInt :: idx, ax, jdx
 21:   PetscScalar :: s(0:3)

 23:   PETSC_AssertAlignx(16, a(1))
 24:   PETSC_AssertAlignx(16, x(1))
 25:   PETSC_AssertAlignx(16, b(1))
 26:   PETSC_AssertAlignx(16, ai(1))
 27:   PETSC_AssertAlignx(16, aj(1))
 28:   PETSC_AssertAlignx(16, adiag(1))

 30:   !
 31:   ! Forward Solve
 32:   !
 33:   x(0:3) = b(0:3)
 34:   idx = 0
 35:   do i = 1, n - 1
 36:     jstart = ai(i)
 37:     jend = adiag(i) - 1
 38:     ax = 16*jstart
 39:     idx = idx + 4
 40:     s(0:3) = b(idx + 0:idx + 3)
 41:     do j = jstart, jend
 42:       jdx = 4*aj(j)

 44:       s(0) = s(0) - (a(ax + 0)*x(jdx + 0) + a(ax + 4)*x(jdx + 1) + a(ax + 8)*x(jdx + 2) + a(ax + 12)*x(jdx + 3))
 45:       s(1) = s(1) - (a(ax + 1)*x(jdx + 0) + a(ax + 5)*x(jdx + 1) + a(ax + 9)*x(jdx + 2) + a(ax + 13)*x(jdx + 3))
 46:       s(2) = s(2) - (a(ax + 2)*x(jdx + 0) + a(ax + 6)*x(jdx + 1) + a(ax + 10)*x(jdx + 2) + a(ax + 14)*x(jdx + 3))
 47:       s(3) = s(3) - (a(ax + 3)*x(jdx + 0) + a(ax + 7)*x(jdx + 1) + a(ax + 11)*x(jdx + 2) + a(ax + 15)*x(jdx + 3))
 48:       ax = ax + 16
 49:     end do
 50:     x(idx + 0:idx + 3) = s(0:3)
 51:   end do

 53:   !
 54:   ! Backward solve the upper triangular
 55:   !
 56:   do i = n - 1, 0, -1
 57:     jstart = adiag(i) + 1
 58:     jend = ai(i + 1) - 1
 59:     ax = 16*jstart
 60:     s(0:3) = x(idx + 0:idx + 3)
 61:     do j = jstart, jend
 62:       jdx = 4*aj(j)
 63:       s(0) = s(0) - (a(ax + 0)*x(jdx + 0) + a(ax + 4)*x(jdx + 1) + a(ax + 8)*x(jdx + 2) + a(ax + 12)*x(jdx + 3))
 64:       s(1) = s(1) - (a(ax + 1)*x(jdx + 0) + a(ax + 5)*x(jdx + 1) + a(ax + 9)*x(jdx + 2) + a(ax + 13)*x(jdx + 3))
 65:       s(2) = s(2) - (a(ax + 2)*x(jdx + 0) + a(ax + 6)*x(jdx + 1) + a(ax + 10)*x(jdx + 2) + a(ax + 14)*x(jdx + 3))
 66:       s(3) = s(3) - (a(ax + 3)*x(jdx + 0) + a(ax + 7)*x(jdx + 1) + a(ax + 11)*x(jdx + 2) + a(ax + 15)*x(jdx + 3))
 67:       ax = ax + 16
 68:     end do
 69:     ax = 16*adiag(i)
 70:     x(idx + 0) = a(ax + 0)*s(0) + a(ax + 4)*s(1) + a(ax + 8)*s(2) + a(ax + 12)*s(3)
 71:     x(idx + 1) = a(ax + 1)*s(0) + a(ax + 5)*s(1) + a(ax + 9)*s(2) + a(ax + 13)*s(3)
 72:     x(idx + 2) = a(ax + 2)*s(0) + a(ax + 6)*s(1) + a(ax + 10)*s(2) + a(ax + 14)*s(3)
 73:     x(idx + 3) = a(ax + 3)*s(0) + a(ax + 7)*s(1) + a(ax + 11)*s(2) + a(ax + 15)*s(3)
 74:     idx = idx - 4
 75:   end do
 76: end subroutine FortranSolveBAIJ4Unroll

! Variant that performs each block-row matrix-vector product inline and
! does not call a BLAS 2 operation for each row block
!
 80: pure subroutine FortranSolveBAIJ4(n, x, ai, aj, adiag, a, b, w)
 81:   use, intrinsic :: ISO_C_binding
 82:   implicit none
 83:   MatScalar, intent(in) :: a(0:*)
 84:   PetscScalar, intent(inout) :: x(0:*), w(0:*)
 85:   PetscScalar, intent(in) :: b(0:*)
 86:   PetscInt, intent(in) :: n
 87:   PetscInt, intent(in) :: ai(0:*), aj(0:*), adiag(0:*)

 89:   PetscInt :: ii, jj, i, j
 90:   PetscInt :: jstart, jend, idx, ax, jdx, kdx, nn
 91:   PetscScalar :: s(0:3)

 93:   PETSC_AssertAlignx(16, a(1))
 94:   PETSC_AssertAlignx(16, w(1))
 95:   PETSC_AssertAlignx(16, x(1))
 96:   PETSC_AssertAlignx(16, b(1))
 97:   PETSC_AssertAlignx(16, ai(1))
 98:   PETSC_AssertAlignx(16, aj(1))
 99:   PETSC_AssertAlignx(16, adiag(1))
100:   !
101:   !     Forward Solve
102:   !
103:   x(0:3) = b(0:3)
104:   idx = 0
105:   do i = 1, n - 1
106:     !
107:     ! Pack required part of vector into work array
108:     !
109:     kdx = 0
110:     jstart = ai(i)
111:     jend = adiag(i) - 1

113:     if (jend - jstart >= 500) error stop 'Overflowing vector FortranSolveBAIJ4()'

115:     do j = jstart, jend
116:       jdx = 4*aj(j)
117:       w(kdx:kdx + 3) = x(jdx:jdx + 3)
118:       kdx = kdx + 4
119:     end do

121:     ax = 16*jstart
122:     idx = idx + 4
123:     s(0:3) = b(idx:idx + 3)
124:     !
125:     !    s = s - a(ax:)*w
126:     !
127:     nn = 4*(jend - jstart + 1) - 1
128:     do ii = 0, 3
129:       do jj = 0, nn
130:         s(ii) = s(ii) - a(ax + 4*jj + ii)*w(jj)
131:       end do
132:     end do

134:     x(idx:idx + 3) = s(0:3)
135:   end do
136:   !
137:   ! Backward solve the upper triangular
138:   !
139:   do i = n - 1, 0, -1
140:     jstart = adiag(i) + 1
141:     jend = ai(i + 1) - 1
142:     ax = 16*jstart
143:     s(0:3) = x(idx:idx + 3)
144:     !
145:     !   Pack each chunk of vector needed
146:     !
147:     kdx = 0
148:     if (jend - jstart >= 500) error stop 'Overflowing vector FortranSolveBAIJ4()'

150:     do j = jstart, jend
151:       jdx = 4*aj(j)
152:       w(kdx:kdx + 3) = x(jdx:jdx + 3)
153:       kdx = kdx + 4
154:     end do
155:     nn = 4*(jend - jstart + 1) - 1
156:     do ii = 0, 3
157:       do jj = 0, nn
158:         s(ii) = s(ii) - a(ax + 4*jj + ii)*w(jj)
159:       end do
160:     end do

162:     ax = 16*adiag(i)
163:     x(idx) = a(ax + 0)*s(0) + a(ax + 4)*s(1) + a(ax + 8)*s(2) + a(ax + 12)*s(3)
164:     x(idx + 1) = a(ax + 1)*s(0) + a(ax + 5)*s(1) + a(ax + 9)*s(2) + a(ax + 13)*s(3)
165:     x(idx + 2) = a(ax + 2)*s(0) + a(ax + 6)*s(1) + a(ax + 10)*s(2) + a(ax + 14)*s(3)
166:     x(idx + 3) = a(ax + 3)*s(0) + a(ax + 7)*s(1) + a(ax + 11)*s(2) + a(ax + 15)*s(3)
167:     idx = idx - 4
168:   end do
169: end subroutine FortranSolveBAIJ4