Actual source code: fsolvebaij.F

petsc-3.6.1 2015-08-06
Report Typos and Errors
  1: !
  2: !
  3: !    Fortran kernel for sparse triangular solve in the BAIJ matrix format
  4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
  5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
  6: !
  7: #include <petsc/finclude/petscsysdef.h>
  8: !

 10:       subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
 11:       implicit none
 12:       MatScalar   a(0:*)
 13:       PetscScalar x(0:*)
 14:       PetscScalar b(0:*)
 15:       PetscInt    n
 16:       PetscInt    ai(0:*)
 17:       PetscInt    aj(0:*)
 18:       PetscInt    adiag(0:*)

 20:       PetscInt    i,j,jstart,jend
 21:       PetscInt    idx,ax,jdx
 22:       PetscScalar s1,s2,s3,s4
 23:       PetscScalar x1,x2,x3,x4
 24: !
 25: !     Forward Solve
 26: !
 27:       PETSC_AssertAlignx(16,a(1))
 28:       PETSC_AssertAlignx(16,x(1))
 29:       PETSC_AssertAlignx(16,b(1))
 30:       PETSC_AssertAlignx(16,ai(1))
 31:       PETSC_AssertAlignx(16,aj(1))
 32:       PETSC_AssertAlignx(16,adiag(1))

 34:          x(0) = b(0)
 35:          x(1) = b(1)
 36:          x(2) = b(2)
 37:          x(3) = b(3)
 38:          idx  = 0
 39:          do 20 i=1,n-1
 40:             jstart = ai(i)
 41:             jend   = adiag(i) - 1
 42:             ax    = 16*jstart
 43:             idx    = idx + 4
 44:             s1     = b(idx)
 45:             s2     = b(idx+1)
 46:             s3     = b(idx+2)
 47:             s4     = b(idx+3)
 48:             do 30 j=jstart,jend
 49:               jdx   = 4*aj(j)

 51:               x1    = x(jdx)
 52:               x2    = x(jdx+1)
 53:               x3    = x(jdx+2)
 54:               x4    = x(jdx+3)
 55:               s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 56:               s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 57:               s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 58:               s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 59:               ax = ax + 16
 60:  30         continue
 61:             x(idx)   = s1
 62:             x(idx+1) = s2
 63:             x(idx+2) = s3
 64:             x(idx+3) = s4
 65:  20      continue


 68: !
 69: !     Backward solve the upper triangular
 70: !
 71:          do 40 i=n-1,0,-1
 72:             jstart  = adiag(i) + 1
 73:             jend    = ai(i+1) - 1
 74:             ax     = 16*jstart
 75:             s1      = x(idx)
 76:             s2      = x(idx+1)
 77:             s3      = x(idx+2)
 78:             s4      = x(idx+3)
 79:             do 50 j=jstart,jend
 80:               jdx   = 4*aj(j)
 81:               x1    = x(jdx)
 82:               x2    = x(jdx+1)
 83:               x3    = x(jdx+2)
 84:               x4    = x(jdx+3)
 85:               s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 86:               s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 87:               s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 88:               s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 89:               ax = ax + 16
 90:  50         continue
 91:             ax      = 16*adiag(i)
 92:             x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
 93:             x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
 94:             x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
 95:             x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
 96:             idx      = idx - 4
 97:  40      continue
 98:       return
 99:       end

!
!     FortranSolveBAIJ4BLAS - triangular solves for a 4x4-block BAIJ
!     factor that gather, for each block row, the needed x entries
!     into w and apply the whole row of blocks with one BLAS dgemv.
!
      subroutine FortranSolveBAIJ4BLAS(n,x,ai,aj,adiag,a,b,w)
!
!     n     - number of block rows; x and b hold 4*n entries
!     x     - (output) the solution
!     ai    - block row pointers into aj and a (0-based), n+1 used
!     aj    - block column index of each stored block (0-based)
!     adiag - position in aj and a of the diagonal block of each row
!     a     - block values, 16 per block, stored column by column:
!             entry (r,c) of stored block k is a(16*k + 4*c + r)
!     b     - right-hand side
!     w     - scratch; receives 4 entries per off-diagonal block of
!             the block row currently being processed
!
!     Row i of L is blocks ai(i) .. adiag(i)-1 (identity diagonal);
!     row i of U is blocks adiag(i)+1 .. ai(i+1)-1.  The block at
!     adiag(i) is applied by multiplication, so it is expected to
!     hold the inverse of the diagonal block of U.
!
!     dgemv is the double precision real BLAS routine, so this code
!     is only correct when MatScalar is double precision real.
!
!     A warning is written when a row has more than 500 off-diagonal
!     blocks (more than 2000 entries of w), but the gather still
!     proceeds.  NOTE(review): 2000 presumably matches the size of
!     w provided by the caller - confirm.
!
      implicit none
      MatScalar   a(0:*),w(0:*)
      PetscScalar x(0:*),b(0:*)
      PetscInt    n,ai(0:*),aj(0:*),adiag(0:*)

      PetscInt    i,j,jstart,jend,idx,ax,jdx,kdx,ncols
      MatScalar   s(0:3)

      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,w(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))
!
!     Nothing to solve for an empty system (otherwise block row 0
!     would still be copied from b into x).
!
      if (n .le. 0) return
!
!     Forward solve with L; block row 0 has no L blocks.
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      do i=1,n-1
         idx    = 4*i
         jstart = ai(i)
         jend   = adiag(i) - 1
!        Gather the x entries used by this row of L into w
         if (jend - jstart .ge. 500) then
           write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
         endif
         kdx = 0
         do j=jstart,jend
            jdx      = 4*aj(j)
            w(kdx)   = x(jdx)
            w(kdx+1) = x(jdx+1)
            w(kdx+2) = x(jdx+2)
            w(kdx+3) = x(jdx+3)
            kdx      = kdx + 4
         end do
!        s = b(idx:idx+3) - (4 x ncols slice of a starting at ax) * w
         s(0)  = b(idx)
         s(1)  = b(idx+1)
         s(2)  = b(idx+2)
         s(3)  = b(idx+3)
         ncols = 4*(jend - jstart + 1)
!        Skip the call for rows without L blocks; a(ax) could then be
!        one past the last stored block.
         if (ncols .gt. 0) then
            ax = 16*jstart
            call dgemv('n',4,ncols,-1.d0,a(ax),4,w,1,1.d0,s,1)
         endif
         x(idx)   = s(0)
         x(idx+1) = s(1)
         x(idx+2) = s(2)
         x(idx+3) = s(3)
      end do
!
!     Backward solve with U, last block row first, in place in x.
!
      do i=n-1,0,-1
         idx    = 4*i
         jstart = adiag(i) + 1
         jend   = ai(i+1) - 1
         s(0)   = x(idx)
         s(1)   = x(idx+1)
         s(2)   = x(idx+2)
         s(3)   = x(idx+3)
!        Gather the x entries used by this row of U into w
         if (jend - jstart .ge. 500) then
           write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
         endif
         kdx = 0
         do j=jstart,jend
            jdx      = 4*aj(j)
            w(kdx)   = x(jdx)
            w(kdx+1) = x(jdx+1)
            w(kdx+2) = x(jdx+2)
            w(kdx+3) = x(jdx+3)
            kdx      = kdx + 4
         end do
         ncols = 4*(jend - jstart + 1)
         if (ncols .gt. 0) then
            ax = 16*jstart
            call dgemv('n',4,ncols,-1.d0,a(ax),4,w,1,1.d0,s,1)
         endif
!        x(idx:idx+3) = (inverted diagonal block) * s
         ax      = 16*adiag(i)
         x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
         x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
         x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
         x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
      end do

      return
      end


!
!     FortranSolveBAIJ4 - triangular solves for a 4x4-block BAIJ
!     factor that gather, for each block row, the needed x entries
!     into w and apply the row of blocks with plain loops (no BLAS).
!
      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
!
!     n     - number of block rows; x and b hold 4*n entries
!     x     - (output) the solution
!     ai    - block row pointers into aj and a (0-based), n+1 used
!     aj    - block column index of each stored block (0-based)
!     adiag - position in aj and a of the diagonal block of each row
!     a     - block values, 16 per block, stored column by column:
!             entry (r,c) of stored block k is a(16*k + 4*c + r)
!     b     - right-hand side
!     w     - scratch; receives 4 entries per off-diagonal block of
!             the block row currently being processed
!
!     Row i of L is blocks ai(i) .. adiag(i)-1 (identity diagonal);
!     row i of U is blocks adiag(i)+1 .. ai(i+1)-1.  The block at
!     adiag(i) is applied by multiplication, so it is expected to
!     hold the inverse of the diagonal block of U.
!
!     A warning is written when a row has more than 500 off-diagonal
!     blocks (more than 2000 entries of w), but the gather still
!     proceeds.  NOTE(review): 2000 presumably matches the size of
!     w provided by the caller - confirm.
!
      implicit none
      MatScalar   a(0:*)
      PetscScalar x(0:*),b(0:*),w(0:*)
      PetscInt    n,ai(0:*),aj(0:*),adiag(0:*)
      PetscInt    ii,jj,i,j
      PetscInt    jstart,jend,idx,ax,jdx,kdx,nn
      PetscScalar s(0:3)

      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,w(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))
!
!     Nothing to solve for an empty system (otherwise block row 0
!     would still be copied from b into x).
!
      if (n .le. 0) return
!
!     Forward solve with L; block row 0 has no L blocks.
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      do i=1,n-1
         idx    = 4*i
         jstart = ai(i)
         jend   = adiag(i) - 1
!        Gather the x entries used by this row of L into w
         if (jend - jstart .ge. 500) then
           write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
         endif
         kdx = 0
         do j=jstart,jend
            jdx      = 4*aj(j)
            w(kdx)   = x(jdx)
            w(kdx+1) = x(jdx+1)
            w(kdx+2) = x(jdx+2)
            w(kdx+3) = x(jdx+3)
            kdx      = kdx + 4
         end do
         s(0) = b(idx)
         s(1) = b(idx+1)
         s(2) = b(idx+2)
         s(3) = b(idx+3)
!
!        s = s - A*w, A being the 4 x (nn+1) slice of a at ax.  The
!        row index ii runs fastest so a is read in storage order;
!        each s(ii) is still summed over jj in increasing order.
!
         nn = 4*(jend - jstart + 1) - 1
         ax = 16*jstart
         do jj=0,nn
            do ii=0,3
               s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
            end do
         end do
         x(idx)   = s(0)
         x(idx+1) = s(1)
         x(idx+2) = s(2)
         x(idx+3) = s(3)
      end do
!
!     Backward solve with U, last block row first, in place in x.
!
      do i=n-1,0,-1
         idx    = 4*i
         jstart = adiag(i) + 1
         jend   = ai(i+1) - 1
         s(0)   = x(idx)
         s(1)   = x(idx+1)
         s(2)   = x(idx+2)
         s(3)   = x(idx+3)
!        Gather the x entries used by this row of U into w
         if (jend - jstart .ge. 500) then
           write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
         endif
         kdx = 0
         do j=jstart,jend
            jdx      = 4*aj(j)
            w(kdx)   = x(jdx)
            w(kdx+1) = x(jdx+1)
            w(kdx+2) = x(jdx+2)
            w(kdx+3) = x(jdx+3)
            kdx      = kdx + 4
         end do
!        s = s - A*w, reading a in storage order as above
         nn = 4*(jend - jstart + 1) - 1
         ax = 16*jstart
         do jj=0,nn
            do ii=0,3
               s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
            end do
         end do
!        x(idx:idx+3) = (inverted diagonal block) * s
         ax      = 16*adiag(i)
         x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
         x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
         x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
         x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
      end do

      return
      end