Actual source code: fsolvebaij.F
petsc-3.7.3 2016-08-01
1: !
2: !
3: ! Fortran kernel for sparse triangular solve in the BAIJ matrix format
4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
6: !
7: #include <petsc/finclude/petscsysdef.h>
8: !
!
!     FortranSolveBAIJ4Unroll - forward and backward block triangular
!     solve with 4x4 blocks; every block-times-vector product is
!     written out term by term.
!
!     Arguments (all arrays are indexed from 0):
!       n     - number of block rows
!       x     - result; entries 0..4*n-1 are written
!       ai    - ai(i) is the position of the first block of row i;
!               entries 1..n are read
!       aj    - block column of every stored block
!       adiag - adiag(i) is the position of the diagonal block of
!               row i
!       a     - block values, 16 per block; within the block that
!               starts at ax, entry ax+4*c+r multiplies input
!               component c and contributes to output component r
!       b     - right-hand side; entries 0..4*n-1 are read
!
      subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
      implicit none
      MatScalar   a(0:*)
      PetscScalar x(0:*)
      PetscScalar b(0:*)
      PetscInt    n
      PetscInt    ai(0:*)
      PetscInt    aj(0:*)
      PetscInt    adiag(0:*)
!
      PetscInt    i,j,jstart,jend
      PetscInt    idx,ax,jdx
      PetscScalar s1,s2,s3,s4
      PetscScalar x1,x2,x3,x4
!
!     NOTE(review): the arrays are declared with lower bound 0, so
!     element 1 named in the next lines is the second entry - confirm
!     that this is the intended element.
!
      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))
!
!     With no block rows there is nothing to solve.  Returning here
!     keeps the copy of the first block below from touching entries
!     that an empty system need not provide.
!
      if (n .le. 0) return
!
!     Forward solve: block row 0 is copied, every later row subtracts
!     the blocks stored before its diagonal block.
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      idx  = 0
      do i=1,n-1
        jstart = ai(i)
        jend   = adiag(i) - 1
        ax     = 16*jstart
        idx    = idx + 4
        s1     = b(idx)
        s2     = b(idx+1)
        s3     = b(idx+2)
        s4     = b(idx+3)
        do j=jstart,jend
          jdx = 4*aj(j)
          x1  = x(jdx)
          x2  = x(jdx+1)
          x3  = x(jdx+2)
          x4  = x(jdx+3)
          s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
          s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
          s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
          s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
          ax = ax + 16
        end do
        x(idx)   = s1
        x(idx+1) = s2
        x(idx+2) = s3
        x(idx+3) = s4
      end do
!
!     Backward solve of the upper triangular part.  idx still points
!     at the last block row (4*(n-1)) when the loop starts and moves
!     down by one block per iteration.
!
      do i=n-1,0,-1
        jstart = adiag(i) + 1
        jend   = ai(i+1) - 1
        ax     = 16*jstart
        s1     = x(idx)
        s2     = x(idx+1)
        s3     = x(idx+2)
        s4     = x(idx+3)
        do j=jstart,jend
          jdx = 4*aj(j)
          x1  = x(jdx)
          x2  = x(jdx+1)
          x3  = x(jdx+2)
          x4  = x(jdx+3)
          s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
          s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
          s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
          s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
          ax = ax + 16
        end do
!
!       The block at adiag(i) is multiplied, not divided by;
!       presumably the factorization stores the inverse of the
!       diagonal block there - confirm against the caller.
!
        ax = 16*adiag(i)
        x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
        x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
        x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
        x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
        idx = idx - 4
      end do
      return
      end
101: !
102: ! version that calls BLAS 2 operation for each row block
103: !
!
!     FortranSolveBAIJ4BLAS - forward and backward block triangular
!     solve with 4x4 blocks.  For every block row the needed pieces
!     of x are gathered into w and one dgemv call applies all
!     off-diagonal blocks of that row at once.
!
!     Arguments (all arrays are indexed from 0):
!       n     - number of block rows
!       x     - result; entries 0..4*n-1 are written
!       ai    - ai(i) is the position of the first block of row i;
!               entries 1..n are read
!       aj    - block column of every stored block
!       adiag - adiag(i) is the position of the diagonal block of
!               row i
!       a     - block values, 16 per block, passed to dgemv as a
!               matrix with 4 rows and leading dimension 4
!       b     - right-hand side; entries 0..4*n-1 are read
!       w     - work array; receives 4 entries for every block of
!               the row being processed
!
!     NOTE(review): dgemv and the 1.d0 literals are double precision
!     real - confirm the scalar types match in every configuration.
!
      subroutine FortranSolveBAIJ4BLAS(n,x,ai,aj,adiag,a,b,w)
      implicit none
      MatScalar   a(0:*),w(0:*)
      PetscScalar x(0:*),b(0:*)
      PetscInt    n,ai(0:*),aj(0:*),adiag(0:*)
!
      PetscInt    i,j,jstart,jend,idx,ax,jdx,kdx
      MatScalar   s(0:3)
!
!     NOTE(review): the arrays are declared with lower bound 0, so
!     element 1 named in the next lines is the second entry - confirm
!     that this is the intended element.
!
      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,w(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))
!
!     With no block rows there is nothing to solve.  Returning here
!     keeps the copy of the first block below from touching entries
!     that an empty system need not provide.
!
      if (n .le. 0) return
!
!     Forward solve
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      idx  = 0
      do i=1,n-1
!
!       Pack required part of vector into work array.  A row with
!       more than 500 blocks only produces the message; packing
!       continues regardless.  NOTE(review): presumably w is sized
!       for 500 blocks - confirm against the caller.
!
        kdx    = 0
        jstart = ai(i)
        jend   = adiag(i) - 1
        if (jend - jstart .ge. 500) then
          write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
        endif
        do j=jstart,jend
          jdx      = 4*aj(j)
          w(kdx)   = x(jdx)
          w(kdx+1) = x(jdx+1)
          w(kdx+2) = x(jdx+2)
          w(kdx+3) = x(jdx+3)
          kdx      = kdx + 4
        end do
!
        ax   = 16*jstart
        idx  = idx + 4
        s(0) = b(idx)
        s(1) = b(idx+1)
        s(2) = b(idx+2)
        s(3) = b(idx+3)
!
!       s = s - a(ax:)*w
!
        call dgemv('n',4,4*(jend-jstart+1),-1.d0,a(ax),4,w,1,1.d0,s,1)
!       call sgemv('n',4,4*(jend-jstart+1),-1.0,a(ax),4,w,1,1.0,s,1)
!
        x(idx)   = s(0)
        x(idx+1) = s(1)
        x(idx+2) = s(2)
        x(idx+3) = s(3)
      end do
!
!     Backward solve the upper triangular.  idx still points at the
!     last block row (4*(n-1)) when the loop starts.
!
      do i=n-1,0,-1
        jstart = adiag(i) + 1
        jend   = ai(i+1) - 1
        ax     = 16*jstart
        s(0)   = x(idx)
        s(1)   = x(idx+1)
        s(2)   = x(idx+2)
        s(3)   = x(idx+3)
!
!       Pack each chunk of vector needed (same 500 block message
!       as in the forward solve).
!
        kdx = 0
        if (jend - jstart .ge. 500) then
          write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
        endif
        do j=jstart,jend
          jdx      = 4*aj(j)
          w(kdx)   = x(jdx)
          w(kdx+1) = x(jdx+1)
          w(kdx+2) = x(jdx+2)
          w(kdx+3) = x(jdx+3)
          kdx      = kdx + 4
        end do
!       call sgemv('n',4,4*(jend-jstart+1),-1.0,a(ax),4,w,1,1.0,s,1)
        call dgemv('n',4,4*(jend-jstart+1),-1.d0,a(ax),4,w,1,1.d0,s,1)
!
!       The block at adiag(i) is multiplied, not divided by;
!       presumably the factorization stores the inverse of the
!       diagonal block there - confirm against the caller.
!
        ax = 16*adiag(i)
        x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
        x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
        x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
        x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
        idx = idx - 4
      end do
!
      return
      end
211: !
212: ! version that does not call BLAS 2 operation for each row block
213: !
!
!     FortranSolveBAIJ4 - forward and backward block triangular solve
!     with 4x4 blocks.  For every block row the needed pieces of x
!     are gathered into w and the product with all off-diagonal
!     blocks of that row is formed by a plain double loop.
!
!     Arguments (all arrays are indexed from 0):
!       n     - number of block rows
!       x     - result; entries 0..4*n-1 are written
!       ai    - ai(i) is the position of the first block of row i;
!               entries 1..n are read
!       aj    - block column of every stored block
!       adiag - adiag(i) is the position of the diagonal block of
!               row i
!       a     - block values, 16 per block; entry ax+4*jj+ii
!               multiplies w(jj) and contributes to s(ii)
!       b     - right-hand side; entries 0..4*n-1 are read
!       w     - work array; receives 4 entries for every block of
!               the row being processed
!
      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
      implicit none
      MatScalar   a(0:*)
      PetscScalar x(0:*),b(0:*),w(0:*)
      PetscInt    n,ai(0:*),aj(0:*),adiag(0:*)
      PetscInt    ii,jj,i,j
!
      PetscInt    jstart,jend,idx,ax,jdx,kdx,nn
      PetscScalar s(0:3)
!
!     NOTE(review): the arrays are declared with lower bound 0, so
!     element 1 named in the next lines is the second entry - confirm
!     that this is the intended element.
!
      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,w(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))
!
!     With no block rows there is nothing to solve.  Returning here
!     keeps the copy of the first block below from touching entries
!     that an empty system need not provide.
!
      if (n .le. 0) return
!
!     Forward solve
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      idx  = 0
      do i=1,n-1
!
!       Pack required part of vector into work array.  A row with
!       more than 500 blocks only produces the message; packing
!       continues regardless.  NOTE(review): presumably w is sized
!       for 500 blocks - confirm against the caller.
!
        kdx    = 0
        jstart = ai(i)
        jend   = adiag(i) - 1
        if (jend - jstart .ge. 500) then
          write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
        endif
        do j=jstart,jend
          jdx      = 4*aj(j)
          w(kdx)   = x(jdx)
          w(kdx+1) = x(jdx+1)
          w(kdx+2) = x(jdx+2)
          w(kdx+3) = x(jdx+3)
          kdx      = kdx + 4
        end do
!
        ax   = 16*jstart
        idx  = idx + 4
        s(0) = b(idx)
        s(1) = b(idx+1)
        s(2) = b(idx+2)
        s(3) = b(idx+3)
!
!       s = s - a(ax:)*w
!
!       jj is the outer loop so that a(ax+4*jj) .. a(ax+4*jj+3) are
!       read consecutively; every s(ii) still receives its terms in
!       increasing jj order.
!
        nn = 4*(jend - jstart + 1) - 1
        do jj=0,nn
          do ii=0,3
            s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
          end do
        end do
!
        x(idx)   = s(0)
        x(idx+1) = s(1)
        x(idx+2) = s(2)
        x(idx+3) = s(3)
      end do
!
!     Backward solve the upper triangular.  idx still points at the
!     last block row (4*(n-1)) when the loop starts.
!
      do i=n-1,0,-1
        jstart = adiag(i) + 1
        jend   = ai(i+1) - 1
        ax     = 16*jstart
        s(0)   = x(idx)
        s(1)   = x(idx+1)
        s(2)   = x(idx+2)
        s(3)   = x(idx+3)
!
!       Pack each chunk of vector needed (same 500 block message
!       as in the forward solve).
!
        kdx = 0
        if (jend - jstart .ge. 500) then
          write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
        endif
        do j=jstart,jend
          jdx      = 4*aj(j)
          w(kdx)   = x(jdx)
          w(kdx+1) = x(jdx+1)
          w(kdx+2) = x(jdx+2)
          w(kdx+3) = x(jdx+3)
          kdx      = kdx + 4
        end do
!
!       s = s - a(ax:)*w, with the same loop order as in the
!       forward solve.
!
        nn = 4*(jend - jstart + 1) - 1
        do jj=0,nn
          do ii=0,3
            s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
          end do
        end do
!
!       The block at adiag(i) is multiplied, not divided by;
!       presumably the factorization stores the inverse of the
!       diagonal block there - confirm against the caller.
!
        ax = 16*adiag(i)
        x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
        x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
        x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
        x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
        idx = idx - 4
      end do
!
      return
      end