Actual source code: baijfact81.c
petsc-3.12.5 2020-03-29
/*
    Factorization code for BAIJ format.
*/
#include <../src/mat/impls/baij/seq/baij.h>
#include <petsc/private/kernels/blockinvert.h>
#if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
#include <immintrin.h>
#endif

/*
   Version for when blocks are 9 by 9
*/
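
/*
   Data layout used below (as read off the code): A (the matrix being factored) and
   B (the factor, aliased to C) are SeqBAIJ, i.e. block CSR with bs = 9 and bs2 = 81
   scalars per block; block k occupies ->a + bs2*k, with block column indices in ->j
   and block row pointers in ->i.  For the factor, row i keeps its L blocks in
   bj[bi[i] .. bi[i+1]) and its strictly upper U blocks in bj[bdiag[i+1]+1 .. bdiag[i]);
   bdiag[i] itself points at the diagonal block, which is stored already inverted.
*/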
#if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
PetscErrorCode MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering(Mat B,Mat A,const MatFactorInfo *info)
{
  Mat             C = B;
  Mat_SeqBAIJ     *a = (Mat_SeqBAIJ*)A->data,*b = (Mat_SeqBAIJ*)C->data;
  PetscInt        i,j,k,nz,nzL,row;
  const PetscInt  n = a->mbs,*ai = a->i,*aj = a->j,*bi = b->i,*bj = b->j;
  const PetscInt  *ajtmp,*bjtmp,*bdiag = b->diag,*pj,bs2 = a->bs2;
  MatScalar       *rtmp,*pc,*mwork,*v,*pv,*aa = a->a;
  PetscInt        flg;
  PetscReal       shift = info->shiftamount;
  PetscBool       allowzeropivot,zeropivotdetected;

  allowzeropivot = PetscNot(A->erroriffailure);

  /* generate work space needed by the factorization */
  PetscMalloc2(bs2*n,&rtmp,bs2,&mwork);
  PetscArrayzero(rtmp,bs2*n);
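  /* rtmp is a work row addressed by block column: rtmp + bs2*col holds the 9x9 block of
     the current row i sitting in block column col, so updates can be scattered by the
     column index; mwork is one bs2-sized scratch block for the multiply kernel below */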

  for (i=0; i<n; i++) {
    /* zero rtmp */
    /* L part */
    nz    = bi[i+1] - bi[i];
    bjtmp = bj + bi[i];
    for (j=0; j<nz; j++) {
      PetscArrayzero(rtmp+bs2*bjtmp[j],bs2);
    }

    /* U part */
    nz    = bdiag[i] - bdiag[i+1];
    bjtmp = bj + bdiag[i+1]+1;
    for (j=0; j<nz; j++) {
      PetscArrayzero(rtmp+bs2*bjtmp[j],bs2);
    }

    /* load in initial (unfactored) row */
    nz    = ai[i+1] - ai[i];
    ajtmp = aj + ai[i];
    v     = aa + bs2*ai[i];
    for (j=0; j<nz; j++) {
      PetscArraycpy(rtmp+bs2*ajtmp[j],v+bs2*j,bs2);
    }

    /* elimination */
    bjtmp = bj + bi[i];
    nzL   = bi[i+1] - bi[i];
    for (k=0; k < nzL; k++) {
      row = bjtmp[k];
      pc  = rtmp + bs2*row;
      for (flg=0,j=0; j<bs2; j++) {
        if (pc[j]!=0.0) {
          flg = 1;
          break;
        }
      }
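      /* skip this elimination step if the L(i,row) block in rtmp is identically zero */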
      if (flg) {
        pv = b->a + bs2*bdiag[row];
        /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
        PetscKernel_A_gets_A_times_B_9(pc,pv,mwork);

        pj = b->j + bdiag[row+1]+1;         /* beginning of U(row,:) */
        pv = b->a + bs2*(bdiag[row+1]+1);
        nz = bdiag[row] - bdiag[row+1] - 1; /* number of entries in U(row,:), excluding diag */
        for (j=0; j<nz; j++) {
          /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
          /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
          v = rtmp + bs2*pj[j];
          PetscKernel_A_gets_A_minus_B_times_C_9(v,pc,pv);
          /* pv is incremented inside PetscKernel_A_gets_A_minus_B_times_C_9 */
        }
        PetscLogFlops(1458*nz+1377); /* flops = 2*bs^3*nz + 2*bs^3 - bs^2 */
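        /* 1377 = 2*9^3 - 9^2 flops for the 9x9 multiply pc = pc*inv(D_row) above;
           1458 = 2*9^3 flops for each rtmp block update in the loop */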
      }
    }

    /* finished row so stick it into b->a */
    /* L part */
    pv = b->a + bs2*bi[i];
    pj = b->j + bi[i];
    nz = bi[i+1] - bi[i];
    for (j=0; j<nz; j++) {
      PetscArraycpy(pv+bs2*j,rtmp+bs2*pj[j],bs2);
    }

    /* Mark diagonal and invert diagonal for simpler triangular solves */
    pv = b->a + bs2*bdiag[i];
    pj = b->j + bdiag[i];
    PetscArraycpy(pv,rtmp+bs2*pj[0],bs2);
    PetscKernel_A_gets_inverse_A_9(pv,shift,allowzeropivot,&zeropivotdetected);
    if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
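    /* The diagonal block just stored is inverted in place so that MatSolve below only
       needs block multiplies; `shift` (info->shiftamount) and `allowzeropivot` steer how
       the dense inversion kernel treats a (near-)zero pivot, and a detected zero pivot
       is recorded on the factor instead of aborting when zero pivots are allowed */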

    /* U part */
    pv = b->a + bs2*(bdiag[i+1]+1);
    pj = b->j + bdiag[i+1]+1;
    nz = bdiag[i] - bdiag[i+1] - 1;
    for (j=0; j<nz; j++) {
      PetscArraycpy(pv+bs2*j,rtmp+bs2*pj[j],bs2);
    }
  }
  PetscFree2(rtmp,mwork);

  C->ops->solve          = MatSolve_SeqBAIJ_9_NaturalOrdering;
  C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_N;
  C->assembled           = PETSC_TRUE;

  PetscLogFlops(1.333333333333*9*9*9*n); /* from inverting diagonal blocks */
  return(0);
}

PetscErrorCode MatSolve_SeqBAIJ_9_NaturalOrdering(Mat A,Vec bb,Vec xx)
{
  Mat_SeqBAIJ       *a = (Mat_SeqBAIJ*)A->data;
  const PetscInt    *ai = a->i,*aj = a->j,*adiag = a->diag,*vi;
  PetscInt          i,k,n = a->mbs;
  PetscInt          nz,bs = A->rmap->bs,bs2 = a->bs2;
  const MatScalar   *aa = a->a,*v;
  PetscScalar       *x,*s,*t,*ls;
  const PetscScalar *b;
  __m256d           a0,a1,a2,a3,a4,a5,w0,w1,w2,w3,s0,s1,s2,v0,v1,v2,v3;
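
  /* Each 9-vector and each 9-entry block column is handled as two full 4-wide AVX
     registers plus one masked lane: the mask _mm256_set_epi64x(0,0,0,1LL<<63) has only
     the sign bit of its lowest 64-bit element set, so _mm256_maskload_pd and
     _mm256_maskstore_pd touch just the single remaining double (element 8) */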

  VecGetArrayRead(bb,&b);
  VecGetArray(xx,&x);
  t = a->solve_work;

  /* forward solve the lower triangular */
  PetscArraycpy(t,b,bs);        /* copy 1st block of b to t */

  for (i=1; i<n; i++) {
    v  = aa + bs2*ai[i];
    vi = aj + ai[i];
    nz = ai[i+1] - ai[i];
    s  = t + bs*i;
    PetscArraycpy(s,b+bs*i,bs); /* copy i-th block of b to t */

    s0 = _mm256_loadu_pd(s+0);
    s1 = _mm256_loadu_pd(s+4);
    s2 = _mm256_maskload_pd(s+8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
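
    /* s_i -= sum_k L(i,vi[k]) * t_{vi[k]}, one block per k: each 9x9 block is swept
       column by column, broadcasting one entry of t_{vi[k]} and accumulating with fnmadd
       into the three registers that hold the 9 running sums */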
    for (k=0; k<nz; k++) {
      w0 = _mm256_set1_pd((t+bs*vi[k])[0]);
      a0 = _mm256_loadu_pd(&v[ 0]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[ 4]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_loadu_pd(&v[ 8]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

      w1 = _mm256_set1_pd((t+bs*vi[k])[1]);
      a3 = _mm256_loadu_pd(&v[ 9]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
      a4 = _mm256_loadu_pd(&v[13]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
      a5 = _mm256_loadu_pd(&v[17]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

      w2 = _mm256_set1_pd((t+bs*vi[k])[2]);
      a0 = _mm256_loadu_pd(&v[18]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
      a1 = _mm256_loadu_pd(&v[22]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
      a2 = _mm256_loadu_pd(&v[26]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

      w3 = _mm256_set1_pd((t+bs*vi[k])[3]);
      a3 = _mm256_loadu_pd(&v[27]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
      a4 = _mm256_loadu_pd(&v[31]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
      a5 = _mm256_loadu_pd(&v[35]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

      w0 = _mm256_set1_pd((t+bs*vi[k])[4]);
      a0 = _mm256_loadu_pd(&v[36]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[40]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_loadu_pd(&v[44]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

      w1 = _mm256_set1_pd((t+bs*vi[k])[5]);
      a3 = _mm256_loadu_pd(&v[45]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
      a4 = _mm256_loadu_pd(&v[49]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
      a5 = _mm256_loadu_pd(&v[53]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

      w2 = _mm256_set1_pd((t+bs*vi[k])[6]);
      a0 = _mm256_loadu_pd(&v[54]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
      a1 = _mm256_loadu_pd(&v[58]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
      a2 = _mm256_loadu_pd(&v[62]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

      w3 = _mm256_set1_pd((t+bs*vi[k])[7]);
      a3 = _mm256_loadu_pd(&v[63]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
      a4 = _mm256_loadu_pd(&v[67]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
      a5 = _mm256_loadu_pd(&v[71]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

      w0 = _mm256_set1_pd((t+bs*vi[k])[8]);
      a0 = _mm256_loadu_pd(&v[72]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[76]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_maskload_pd(v+80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
      s2 = _mm256_fnmadd_pd(a2,w0,s2);

      v += bs2;
    }
    _mm256_storeu_pd(&s[0], s0);
    _mm256_storeu_pd(&s[4], s1);
    _mm256_maskstore_pd(&s[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), s2);
  }

  /* backward solve the upper triangular */
  ls = a->solve_work + A->cmap->n;
  for (i=n-1; i>=0; i--) {
    v  = aa + bs2*(adiag[i+1]+1);
    vi = aj + adiag[i+1]+1;
    nz = adiag[i] - adiag[i+1]-1;
    PetscArraycpy(ls,t+i*bs,bs);

    s0 = _mm256_loadu_pd(ls+0);
    s1 = _mm256_loadu_pd(ls+4);
    s2 = _mm256_maskload_pd(ls+8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
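
    /* ls = t_i - sum_k U(i,vi[k]) * x_{vi[k]}; the already-computed solution blocks for
       rows > i live in t, and each block of U is swept column by column exactly as in
       the forward solve */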
    for (k=0; k<nz; k++) {
      w0 = _mm256_set1_pd((t+bs*vi[k])[0]);
      a0 = _mm256_loadu_pd(&v[ 0]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[ 4]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_loadu_pd(&v[ 8]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

      /* v += 9; */
      w1 = _mm256_set1_pd((t+bs*vi[k])[1]);
      a3 = _mm256_loadu_pd(&v[ 9]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
      a4 = _mm256_loadu_pd(&v[13]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
      a5 = _mm256_loadu_pd(&v[17]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

      /* v += 9; */
      w2 = _mm256_set1_pd((t+bs*vi[k])[2]);
      a0 = _mm256_loadu_pd(&v[18]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
      a1 = _mm256_loadu_pd(&v[22]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
      a2 = _mm256_loadu_pd(&v[26]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

      /* v += 9; */
      w3 = _mm256_set1_pd((t+bs*vi[k])[3]);
      a3 = _mm256_loadu_pd(&v[27]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
      a4 = _mm256_loadu_pd(&v[31]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
      a5 = _mm256_loadu_pd(&v[35]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

      /* v += 9; */
      w0 = _mm256_set1_pd((t+bs*vi[k])[4]);
      a0 = _mm256_loadu_pd(&v[36]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[40]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_loadu_pd(&v[44]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

      /* v += 9; */
      w1 = _mm256_set1_pd((t+bs*vi[k])[5]);
      a3 = _mm256_loadu_pd(&v[45]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
      a4 = _mm256_loadu_pd(&v[49]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
      a5 = _mm256_loadu_pd(&v[53]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

      /* v += 9; */
      w2 = _mm256_set1_pd((t+bs*vi[k])[6]);
      a0 = _mm256_loadu_pd(&v[54]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
      a1 = _mm256_loadu_pd(&v[58]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
      a2 = _mm256_loadu_pd(&v[62]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

      /* v += 9; */
      w3 = _mm256_set1_pd((t+bs*vi[k])[7]);
      a3 = _mm256_loadu_pd(&v[63]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
      a4 = _mm256_loadu_pd(&v[67]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
      a5 = _mm256_loadu_pd(&v[71]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

      /* v += 9; */
      w0 = _mm256_set1_pd((t+bs*vi[k])[8]);
      a0 = _mm256_loadu_pd(&v[72]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[76]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_maskload_pd(v+80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
      s2 = _mm256_fnmadd_pd(a2,w0,s2);

      v += bs2;
    }

    _mm256_storeu_pd(&ls[0], s0); _mm256_storeu_pd(&ls[4], s1); _mm256_maskstore_pd(&ls[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), s2);
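
    /* multiply by the inverted diagonal block stored by the factorization:
       t_i = inv(D_i) * ls, again accumulated one block column at a time */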
    w0 = _mm256_setzero_pd(); w1 = _mm256_setzero_pd(); w2 = _mm256_setzero_pd();

    /* first column */
    v0 = _mm256_set1_pd(ls[0]);
    a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[0]); w0 = _mm256_fmadd_pd(a0,v0,w0);
    a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[4]); w1 = _mm256_fmadd_pd(a1,v0,w1);
    a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[8]); w2 = _mm256_fmadd_pd(a2,v0,w2);

    /* second column */
    v1 = _mm256_set1_pd(ls[1]);
    a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[9]);  w0 = _mm256_fmadd_pd(a3,v1,w0);
    a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[13]); w1 = _mm256_fmadd_pd(a4,v1,w1);
    a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[17]); w2 = _mm256_fmadd_pd(a5,v1,w2);

    /* third column */
    v2 = _mm256_set1_pd(ls[2]);
    a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[18]); w0 = _mm256_fmadd_pd(a0,v2,w0);
    a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[22]); w1 = _mm256_fmadd_pd(a1,v2,w1);
    a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[26]); w2 = _mm256_fmadd_pd(a2,v2,w2);

    /* fourth column */
    v3 = _mm256_set1_pd(ls[3]);
    a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[27]); w0 = _mm256_fmadd_pd(a3,v3,w0);
    a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[31]); w1 = _mm256_fmadd_pd(a4,v3,w1);
    a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[35]); w2 = _mm256_fmadd_pd(a5,v3,w2);

    /* fifth column */
    v0 = _mm256_set1_pd(ls[4]);
    a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[36]); w0 = _mm256_fmadd_pd(a0,v0,w0);
    a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[40]); w1 = _mm256_fmadd_pd(a1,v0,w1);
    a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[44]); w2 = _mm256_fmadd_pd(a2,v0,w2);

    /* sixth column */
    v1 = _mm256_set1_pd(ls[5]);
    a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[45]); w0 = _mm256_fmadd_pd(a3,v1,w0);
    a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[49]); w1 = _mm256_fmadd_pd(a4,v1,w1);
    a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[53]); w2 = _mm256_fmadd_pd(a5,v1,w2);

    /* seventh column */
    v2 = _mm256_set1_pd(ls[6]);
    a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[54]); w0 = _mm256_fmadd_pd(a0,v2,w0);
    a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[58]); w1 = _mm256_fmadd_pd(a1,v2,w1);
    a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[62]); w2 = _mm256_fmadd_pd(a2,v2,w2);

    /* eighth column */
    v3 = _mm256_set1_pd(ls[7]);
    a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[63]); w0 = _mm256_fmadd_pd(a3,v3,w0);
    a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[67]); w1 = _mm256_fmadd_pd(a4,v3,w1);
    a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[71]); w2 = _mm256_fmadd_pd(a5,v3,w2);

    /* ninth column */
    v0 = _mm256_set1_pd(ls[8]);
    a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[72]); w0 = _mm256_fmadd_pd(a3,v0,w0);
    a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[76]); w1 = _mm256_fmadd_pd(a4,v0,w1);
    a2 = _mm256_maskload_pd((&(aa+bs2*adiag[i])[80]), _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
    w2 = _mm256_fmadd_pd(a2,v0,w2);

    _mm256_storeu_pd(&(t+i*bs)[0], w0); _mm256_storeu_pd(&(t+i*bs)[4], w1); _mm256_maskstore_pd(&(t+i*bs)[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), w2);

    PetscArraycpy(x+i*bs,t+i*bs,bs);
  }

  VecRestoreArrayRead(bb,&b);
  VecRestoreArray(xx,&x);
  PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n);
  return(0);
}
#endif
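
/*
   Usage sketch (illustrative only, not part of this file; names of sizes such as
   nblocks and maxblocks_per_row are placeholders): a sequential BAIJ matrix with block
   size 9, factored by LU with the natural ordering, is the situation in which the
   *_9_NaturalOrdering routines above are installed on the factor.  Roughly:

     Mat A; Vec x, b; KSP ksp; PC pc;
     MatCreateSeqBAIJ(PETSC_COMM_SELF,9,9*nblocks,9*nblocks,maxblocks_per_row,NULL,&A);
     ... MatSetValuesBlocked(...) / MatAssemblyBegin(...) / MatAssemblyEnd(...) ...
     KSPCreate(PETSC_COMM_SELF,&ksp);
     KSPSetOperators(ksp,A,A);
     KSPSetType(ksp,KSPPREONLY);
     KSPGetPC(ksp,&pc);
     PCSetType(pc,PCLU);
     ... run with -pc_factor_mat_ordering_type natural to request the natural ordering ...
     KSPSolve(ksp,b,x);   triangular solves go through MatSolve_SeqBAIJ_9_NaturalOrdering
*/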