Actual source code: baijfact81.c

petsc-3.11.4 2019-09-28
Report Typos and Errors

  2: /*
  3:    Factorization code for BAIJ format.
  4:  */
  5:  #include <../src/mat/impls/baij/seq/baij.h>
  6:  #include <petsc/private/kernels/blockinvert.h>
  7: #if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
  8: #include <immintrin.h>
  9: #endif
 10: /*
 11:    Version for when blocks are 9 by 9
 12:  */
 13: #if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
 14: PetscErrorCode MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering(Mat B,Mat A,const MatFactorInfo *info)
 15: {
 16:   Mat            C =B;
 17:   Mat_SeqBAIJ    *a=(Mat_SeqBAIJ*)A->data,*b=(Mat_SeqBAIJ*)C->data;
 19:   PetscInt       i,j,k,nz,nzL,row;
 20:   const PetscInt n=a->mbs,*ai=a->i,*aj=a->j,*bi=b->i,*bj=b->j;
 21:   const PetscInt *ajtmp,*bjtmp,*bdiag=b->diag,*pj,bs2=a->bs2;
 22:   MatScalar      *rtmp,*pc,*mwork,*v,*pv,*aa=a->a;
 23:   PetscInt       flg;
 24:   PetscReal      shift = info->shiftamount;
 25:   PetscBool      allowzeropivot,zeropivotdetected;

 28:   allowzeropivot = PetscNot(A->erroriffailure);

 30:   /* generate work space needed by the factorization */
 31:   PetscMalloc2(bs2*n,&rtmp,bs2,&mwork);
 32:   PetscMemzero(rtmp,bs2*n*sizeof(MatScalar));

 34:   for (i=0; i<n; i++) {
 35:     /* zero rtmp */
 36:     /* L part */
 37:     nz    = bi[i+1] - bi[i];
 38:     bjtmp = bj + bi[i];
 39:     for  (j=0; j<nz; j++) {
 40:       PetscMemzero(rtmp+bs2*bjtmp[j],bs2*sizeof(MatScalar));
 41:     }

 43:     /* U part */
 44:     nz    = bdiag[i] - bdiag[i+1];
 45:     bjtmp = bj + bdiag[i+1]+1;
 46:     for  (j=0; j<nz; j++) {
 47:       PetscMemzero(rtmp+bs2*bjtmp[j],bs2*sizeof(MatScalar));
 48:     }

 50:     /* load in initial (unfactored row) */
 51:     nz    = ai[i+1] - ai[i];
 52:     ajtmp = aj + ai[i];
 53:     v     = aa + bs2*ai[i];
 54:     for (j=0; j<nz; j++) {
 55:       PetscMemcpy(rtmp+bs2*ajtmp[j],v+bs2*j,bs2*sizeof(MatScalar));
 56:     }

 58:     /* elimination */
 59:     bjtmp = bj + bi[i];
 60:     nzL   = bi[i+1] - bi[i];
 61:     for (k=0; k < nzL; k++) {
 62:       row = bjtmp[k];
 63:       pc  = rtmp + bs2*row;
 64:       for (flg=0,j=0; j<bs2; j++) {
 65:         if (pc[j]!=0.0) {
 66:           flg = 1;
 67:           break;
 68:         }
 69:       }
 70:       if (flg) {
 71:         pv = b->a + bs2*bdiag[row];
 72:         /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
 73:         PetscKernel_A_gets_A_times_B_9(pc,pv,mwork);

 75:         pj = b->j + bdiag[row+1]+1; /* begining of U(row,:) */
 76:         pv = b->a + bs2*(bdiag[row+1]+1);
 77:         nz = bdiag[row] - bdiag[row+1] - 1; /* num of entries inU(row,:), excluding diag */
 78:         for (j=0; j<nz; j++) {
 79:           /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
 80:           /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
 81:           v    = rtmp + bs2*pj[j];
 82:           PetscKernel_A_gets_A_minus_B_times_C_9(v,pc,pv);
 83:           /* pv incremented in PetscKernel_A_gets_A_minus_B_times_C_9 */
 84:         }
 85:         PetscLogFlops(1458*nz+1377); /* flops = 2*bs^3*nz + 2*bs^3 - bs2) */
 86:       }
 87:     }

 89:     /* finished row so stick it into b->a */
 90:     /* L part */
 91:     pv = b->a + bs2*bi[i];
 92:     pj = b->j + bi[i];
 93:     nz = bi[i+1] - bi[i];
 94:     for (j=0; j<nz; j++) {
 95:       PetscMemcpy(pv+bs2*j,rtmp+bs2*pj[j],bs2*sizeof(MatScalar));
 96:     }

 98:     /* Mark diagonal and invert diagonal for simplier triangular solves */
 99:     pv   = b->a + bs2*bdiag[i];
100:     pj   = b->j + bdiag[i];
101:     PetscMemcpy(pv,rtmp+bs2*pj[0],bs2*sizeof(MatScalar));
102:     PetscKernel_A_gets_inverse_A_9(pv,shift,allowzeropivot,&zeropivotdetected);
103:     if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;

105:     /* U part */
106:     pv = b->a + bs2*(bdiag[i+1]+1);
107:     pj = b->j + bdiag[i+1]+1;
108:     nz = bdiag[i] - bdiag[i+1] - 1;
109:     for (j=0; j<nz; j++) {
110:       PetscMemcpy(pv+bs2*j,rtmp+bs2*pj[j],bs2*sizeof(MatScalar));
111:     }
112:   }
113:   PetscFree2(rtmp,mwork);

115:   C->ops->solve          = MatSolve_SeqBAIJ_9_NaturalOrdering;
116:   C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_N;
117:   C->assembled           = PETSC_TRUE;

119:   PetscLogFlops(1.333333333333*9*9*9*n); /* from inverting diagonal blocks */
120:   return(0);
121: }

123: PetscErrorCode MatSolve_SeqBAIJ_9_NaturalOrdering(Mat A,Vec bb,Vec xx)
124: {
125:   Mat_SeqBAIJ    *a=(Mat_SeqBAIJ*)A->data;
127:   const PetscInt *ai=a->i,*aj=a->j,*adiag=a->diag,*vi;
128:   PetscInt       i,k,n=a->mbs;
129:   PetscInt       nz,bs=A->rmap->bs,bs2=a->bs2;
130:   const MatScalar   *aa=a->a,*v;
131:   PetscScalar       *x,*s,*t,*ls;
132:   const PetscScalar *b;
133:   __m256d a0,a1,a2,a3,a4,a5,w0,w1,w2,w3,s0,s1,s2,v0,v1,v2,v3;

136:   VecGetArrayRead(bb,&b);
137:   VecGetArray(xx,&x);
138:   t    = a->solve_work;

140:   /* forward solve the lower triangular */
141:   PetscMemcpy(t,b,bs*sizeof(PetscScalar)); /* copy 1st block of b to t */

143:   for (i=1; i<n; i++) {
144:     v    = aa + bs2*ai[i];
145:     vi   = aj + ai[i];
146:     nz   = ai[i+1] - ai[i];
147:     s    = t + bs*i;
148:     PetscMemcpy(s,b+bs*i,bs*sizeof(PetscScalar)); /* copy i_th block of b to t */

150:     __m256d s0,s1,s2;
151:     s0 = _mm256_loadu_pd(s+0);
152:     s1 = _mm256_loadu_pd(s+4);
153:     s2 = _mm256_maskload_pd(s+8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));

155:     for (k=0;k<nz;k++) {

157:       w0 = _mm256_set1_pd((t+bs*vi[k])[0]);
158:       a0 = _mm256_loadu_pd(&v[ 0]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
159:       a1 = _mm256_loadu_pd(&v[ 4]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
160:       a2 = _mm256_loadu_pd(&v[ 8]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

162:       w1 = _mm256_set1_pd((t+bs*vi[k])[1]);
163:       a3 = _mm256_loadu_pd(&v[ 9]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
164:       a4 = _mm256_loadu_pd(&v[13]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
165:       a5 = _mm256_loadu_pd(&v[17]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

167:       w2 = _mm256_set1_pd((t+bs*vi[k])[2]);
168:       a0 = _mm256_loadu_pd(&v[18]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
169:       a1 = _mm256_loadu_pd(&v[22]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
170:       a2 = _mm256_loadu_pd(&v[26]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

172:       w3 = _mm256_set1_pd((t+bs*vi[k])[3]);
173:       a3 = _mm256_loadu_pd(&v[27]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
174:       a4 = _mm256_loadu_pd(&v[31]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
175:       a5 = _mm256_loadu_pd(&v[35]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

177:       w0 = _mm256_set1_pd((t+bs*vi[k])[4]);
178:       a0 = _mm256_loadu_pd(&v[36]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
179:       a1 = _mm256_loadu_pd(&v[40]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
180:       a2 = _mm256_loadu_pd(&v[44]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

182:       w1 = _mm256_set1_pd((t+bs*vi[k])[5]);
183:       a3 = _mm256_loadu_pd(&v[45]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
184:       a4 = _mm256_loadu_pd(&v[49]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
185:       a5 = _mm256_loadu_pd(&v[53]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

187:       w2 = _mm256_set1_pd((t+bs*vi[k])[6]);
188:       a0 = _mm256_loadu_pd(&v[54]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
189:       a1 = _mm256_loadu_pd(&v[58]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
190:       a2 = _mm256_loadu_pd(&v[62]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

192:       w3 = _mm256_set1_pd((t+bs*vi[k])[7]);
193:       a3 = _mm256_loadu_pd(&v[63]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
194:       a4 = _mm256_loadu_pd(&v[67]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
195:       a5 = _mm256_loadu_pd(&v[71]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

197:       w0 = _mm256_set1_pd((t+bs*vi[k])[8]);
198:       a0 = _mm256_loadu_pd(&v[72]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
199:       a1 = _mm256_loadu_pd(&v[76]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
200:       a2 = _mm256_maskload_pd(v+80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
201:       s2 = _mm256_fnmadd_pd(a2,w0,s2);
202:       v += bs2;
203:     }
204:          _mm256_storeu_pd(&s[0], s0);
205:          _mm256_storeu_pd(&s[4], s1);
206:          _mm256_maskstore_pd(&s[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), s2);
207:   }

209:   /* backward solve the upper triangular */
210:   ls = a->solve_work + A->cmap->n;
211:   for (i=n-1; i>=0; i--) {
212:     v    = aa + bs2*(adiag[i+1]+1);
213:     vi   = aj + adiag[i+1]+1;
214:     nz   = adiag[i] - adiag[i+1]-1;
215:     PetscMemcpy(ls,t+i*bs,bs*sizeof(PetscScalar));

217:     s0 = _mm256_loadu_pd(ls+0);
218:     s1 = _mm256_loadu_pd(ls+4);
219:     s2 = _mm256_maskload_pd(ls+8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));

221:     for (k=0; k<nz; k++) {

223:       w0 = _mm256_set1_pd((t+bs*vi[k])[0]);
224:       a0 = _mm256_loadu_pd(&v[ 0]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
225:       a1 = _mm256_loadu_pd(&v[ 4]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
226:       a2 = _mm256_loadu_pd(&v[ 8]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

228:       //v += 9;
229:       w1 = _mm256_set1_pd((t+bs*vi[k])[1]);
230:       a3 = _mm256_loadu_pd(&v[ 9]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
231:       a4 = _mm256_loadu_pd(&v[13]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
232:       a5 = _mm256_loadu_pd(&v[17]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

234:       //v += 9;
235:       w2 = _mm256_set1_pd((t+bs*vi[k])[2]);
236:       a0 = _mm256_loadu_pd(&v[18]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
237:       a1 = _mm256_loadu_pd(&v[22]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
238:       a2 = _mm256_loadu_pd(&v[26]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

240:       //v += 9;
241:       w3 = _mm256_set1_pd((t+bs*vi[k])[3]);
242:       a3 = _mm256_loadu_pd(&v[27]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
243:       a4 = _mm256_loadu_pd(&v[31]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
244:       a5 = _mm256_loadu_pd(&v[35]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

246:       //v += 9;
247:       w0 = _mm256_set1_pd((t+bs*vi[k])[4]);
248:       a0 = _mm256_loadu_pd(&v[36]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
249:       a1 = _mm256_loadu_pd(&v[40]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
250:       a2 = _mm256_loadu_pd(&v[44]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

252:       //v += 9;
253:       w1 = _mm256_set1_pd((t+bs*vi[k])[5]);
254:       a3 = _mm256_loadu_pd(&v[45]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
255:       a4 = _mm256_loadu_pd(&v[49]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
256:       a5 = _mm256_loadu_pd(&v[53]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

258:       //v += 9;
259:       w2 = _mm256_set1_pd((t+bs*vi[k])[6]);
260:       a0 = _mm256_loadu_pd(&v[54]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
261:       a1 = _mm256_loadu_pd(&v[58]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
262:       a2 = _mm256_loadu_pd(&v[62]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

264:       //v += 9;
265:       w3 = _mm256_set1_pd((t+bs*vi[k])[7]);
266:       a3 = _mm256_loadu_pd(&v[63]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
267:       a4 = _mm256_loadu_pd(&v[67]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
268:       a5 = _mm256_loadu_pd(&v[71]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

270:       //v += 9;
271:       w0 = _mm256_set1_pd((t+bs*vi[k])[8]);
272:       a0 = _mm256_loadu_pd(&v[72]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
273:       a1 = _mm256_loadu_pd(&v[76]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
274:       a2 = _mm256_maskload_pd(v+80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
275:       s2 = _mm256_fnmadd_pd(a2,w0,s2);
276:       v += bs2;
277:     }

279:          _mm256_storeu_pd(&ls[0], s0); _mm256_storeu_pd(&ls[4], s1); _mm256_maskstore_pd(&ls[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), s2);

281:     w0 = _mm256_setzero_pd(); w1 = _mm256_setzero_pd(); w2 = _mm256_setzero_pd();

283:     // first row
284:     v0 = _mm256_set1_pd(ls[0]);
285:     a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[0]); w0 = _mm256_fmadd_pd(a0,v0,w0);
286:     a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[4]); w1 = _mm256_fmadd_pd(a1,v0,w1);
287:     a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[8]); w2 = _mm256_fmadd_pd(a2,v0,w2);

289:     // second row
290:     v1 = _mm256_set1_pd(ls[1]);
291:     a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[9]); w0 = _mm256_fmadd_pd(a3,v1,w0);
292:     a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[13]); w1 = _mm256_fmadd_pd(a4,v1,w1);
293:     a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[17]); w2 = _mm256_fmadd_pd(a5,v1,w2);

295:     // third row
296:     v2 = _mm256_set1_pd(ls[2]);
297:     a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[18]); w0 = _mm256_fmadd_pd(a0,v2,w0);
298:     a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[22]); w1 = _mm256_fmadd_pd(a1,v2,w1);
299:     a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[26]); w2 = _mm256_fmadd_pd(a2,v2,w2);

301:     // fourth row
302:     v3 = _mm256_set1_pd(ls[3]);
303:     a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[27]); w0 = _mm256_fmadd_pd(a3,v3,w0);
304:     a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[31]); w1 = _mm256_fmadd_pd(a4,v3,w1);
305:     a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[35]); w2 = _mm256_fmadd_pd(a5,v3,w2);

307:     // fifth row
308:     v0 = _mm256_set1_pd(ls[4]);
309:     a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[36]); w0 = _mm256_fmadd_pd(a0,v0,w0);
310:     a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[40]); w1 = _mm256_fmadd_pd(a1,v0,w1);
311:     a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[44]); w2 = _mm256_fmadd_pd(a2,v0,w2);

313:     // sixth row
314:     v1 = _mm256_set1_pd(ls[5]);
315:     a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[45]); w0 = _mm256_fmadd_pd(a3,v1,w0);
316:     a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[49]); w1 = _mm256_fmadd_pd(a4,v1,w1);
317:     a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[53]); w2 = _mm256_fmadd_pd(a5,v1,w2);

319:     // seventh row
320:     v2 = _mm256_set1_pd(ls[6]);
321:     a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[54]); w0 = _mm256_fmadd_pd(a0,v2,w0);
322:     a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[58]); w1 = _mm256_fmadd_pd(a1,v2,w1);
323:     a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[62]); w2 = _mm256_fmadd_pd(a2,v2,w2);

325:     // eighth row
326:     v3 = _mm256_set1_pd(ls[7]);
327:     a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[63]); w0 = _mm256_fmadd_pd(a3,v3,w0);
328:     a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[67]); w1 = _mm256_fmadd_pd(a4,v3,w1);
329:     a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[71]); w2 = _mm256_fmadd_pd(a5,v3,w2);

331:     // ninth row
332:     v0 = _mm256_set1_pd(ls[8]);
333:     a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[72]); w0 = _mm256_fmadd_pd(a3,v0,w0);
334:     a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[76]); w1 = _mm256_fmadd_pd(a4,v0,w1);
335:     a2 = _mm256_maskload_pd((&(aa+bs2*adiag[i])[80]), _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
336:     w2 = _mm256_fmadd_pd(a2,v0,w2);

338:     _mm256_storeu_pd(&(t+i*bs)[0], w0); _mm256_storeu_pd(&(t+i*bs)[4], w1); _mm256_maskstore_pd(&(t+i*bs)[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), w2);

340:     PetscMemcpy(x+i*bs,t+i*bs,bs*sizeof(PetscScalar));
341:   }

343:   VecRestoreArrayRead(bb,&b);
344:   VecRestoreArray(xx,&x);
345:   PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n);
346:   return(0);
347: }
348: #endif