Actual source code: baijsolvtrann.c

petsc-3.14.6 2021-03-30
Report Typos and Errors
  1: #include <../src/mat/impls/baij/seq/baij.h>
  2: #include <petsc/private/kernels/blockinvert.h>

  4: /* ----------------------------------------------------------- */
  5: PetscErrorCode MatSolveTranspose_SeqBAIJ_N_inplace(Mat A,Vec bb,Vec xx)
  6: {
  7:   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
  8:   IS                iscol=a->col,isrow=a->row;
  9:   PetscErrorCode    ierr;
 10:   const PetscInt    *r,*c,*rout,*cout,*ai=a->i,*aj=a->j,*vi;
 11:   PetscInt          i,nz,j;
 12:   const PetscInt    n  =a->mbs,bs=A->rmap->bs,bs2=a->bs2;
 13:   const MatScalar   *aa=a->a,*v;
 14:   PetscScalar       *x,*t,*ls;
 15:   const PetscScalar *b;

 18:   VecGetArrayRead(bb,&b);
 19:   VecGetArray(xx,&x);
 20:   t    = a->solve_work;

 22:   ISGetIndices(isrow,&rout); r = rout;
 23:   ISGetIndices(iscol,&cout); c = cout;

 25:   /* copy the b into temp work space according to permutation */
 26:   for (i=0; i<n; i++) {
 27:     for (j=0; j<bs; j++) {
 28:       t[i*bs+j] = b[c[i]*bs+j];
 29:     }
 30:   }


 33:   /* forward solve the upper triangular transpose */
 34:   ls = a->solve_work + A->cmap->n;
 35:   for (i=0; i<n; i++) {
 36:     PetscArraycpy(ls,t+i*bs,bs);
 37:     PetscKernel_w_gets_transA_times_v(bs,ls,aa+bs2*a->diag[i],t+i*bs);
 38:     v  = aa + bs2*(a->diag[i] + 1);
 39:     vi = aj + a->diag[i] + 1;
 40:     nz = ai[i+1] - a->diag[i] - 1;
 41:     while (nz--) {
 42:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(*vi++),v,t+i*bs);
 43:       v += bs2;
 44:     }
 45:   }

 47:   /* backward solve the lower triangular transpose */
 48:   for (i=n-1; i>=0; i--) {
 49:     v  = aa + bs2*ai[i];
 50:     vi = aj + ai[i];
 51:     nz = a->diag[i] - ai[i];
 52:     while (nz--) {
 53:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(*vi++),v,t+i*bs);
 54:       v += bs2;
 55:     }
 56:   }

 58:   /* copy t into x according to permutation */
 59:   for (i=0; i<n; i++) {
 60:     for (j=0; j<bs; j++) {
 61:       x[bs*r[i]+j]   = t[bs*i+j];
 62:     }
 63:   }

 65:   ISRestoreIndices(isrow,&rout);
 66:   ISRestoreIndices(iscol,&cout);
 67:   VecRestoreArrayRead(bb,&b);
 68:   VecRestoreArray(xx,&x);
 69:   PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n);
 70:   return(0);
 71: }

 73: PetscErrorCode MatSolveTranspose_SeqBAIJ_N(Mat A,Vec bb,Vec xx)
 74: {
 75:   Mat_SeqBAIJ       *a   =(Mat_SeqBAIJ*)A->data;
 76:   IS                iscol=a->col,isrow=a->row;
 77:   PetscErrorCode    ierr;
 78:   const PetscInt    *r,*c,*rout,*cout;
 79:   const PetscInt    n=a->mbs,*ai=a->i,*aj=a->j,*vi,*diag=a->diag;
 80:   PetscInt          i,j,nz;
 81:   const PetscInt    bs =A->rmap->bs,bs2=a->bs2;
 82:   const MatScalar   *aa=a->a,*v;
 83:   PetscScalar       *x,*t,*ls;
 84:   const PetscScalar *b;

 87:   VecGetArrayRead(bb,&b);
 88:   VecGetArray(xx,&x);
 89:   t    = a->solve_work;

 91:   ISGetIndices(isrow,&rout); r = rout;
 92:   ISGetIndices(iscol,&cout); c = cout;

 94:   /* copy the b into temp work space according to permutation */
 95:   for (i=0; i<n; i++) {
 96:     for (j=0; j<bs; j++) {
 97:       t[i*bs+j] = b[c[i]*bs+j];
 98:     }
 99:   }


102:   /* forward solve the upper triangular transpose */
103:   ls = a->solve_work + A->cmap->n;
104:   for (i=0; i<n; i++) {
105:     PetscArraycpy(ls,t+i*bs,bs);
106:     PetscKernel_w_gets_transA_times_v(bs,ls,aa+bs2*diag[i],t+i*bs);
107:     v  = aa + bs2*(diag[i] - 1);
108:     vi = aj + diag[i] - 1;
109:     nz = diag[i] - diag[i+1] - 1;
110:     for (j=0; j>-nz; j--) {
111:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(vi[j]),v,t+i*bs);
112:       v -= bs2;
113:     }
114:   }

116:   /* backward solve the lower triangular transpose */
117:   for (i=n-1; i>=0; i--) {
118:     v  = aa + bs2*ai[i];
119:     vi = aj + ai[i];
120:     nz = ai[i+1] - ai[i];
121:     for (j=0; j<nz; j++) {
122:       PetscKernel_v_gets_v_minus_transA_times_w(bs,t+bs*(vi[j]),v,t+i*bs);
123:       v += bs2;
124:     }
125:   }

127:   /* copy t into x according to permutation */
128:   for (i=0; i<n; i++) {
129:     for (j=0; j<bs; j++) {
130:       x[bs*r[i]+j]   = t[bs*i+j];
131:     }
132:   }

134:   ISRestoreIndices(isrow,&rout);
135:   ISRestoreIndices(iscol,&cout);
136:   VecRestoreArrayRead(bb,&b);
137:   VecRestoreArray(xx,&x);
138:   PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n);
139:   return(0);
140: }