Actual source code: baijsolvtran2.c
petsc-3.14.6 2021-03-30
1: #include <../src/mat/impls/baij/seq/baij.h>
2: #include <petsc/private/kernels/blockinvert.h>
4: PetscErrorCode MatSolveTranspose_SeqBAIJ_2_inplace(Mat A,Vec bb,Vec xx)
5: {
6: Mat_SeqBAIJ *a =(Mat_SeqBAIJ*)A->data;
7: IS iscol=a->col,isrow=a->row;
8: PetscErrorCode ierr;
9: const PetscInt *r,*c,*rout,*cout;
10: const PetscInt *diag=a->diag,n=a->mbs,*vi,*ai=a->i,*aj=a->j;
11: PetscInt i,nz,idx,idt,ii,ic,ir,oidx;
12: const MatScalar *aa=a->a,*v;
13: PetscScalar s1,s2,x1,x2,*x,*t;
14: const PetscScalar *b;
17: VecGetArrayRead(bb,&b);
18: VecGetArray(xx,&x);
19: t = a->solve_work;
21: ISGetIndices(isrow,&rout); r = rout;
22: ISGetIndices(iscol,&cout); c = cout;
24: /* copy the b into temp work space according to permutation */
25: ii = 0;
26: for (i=0; i<n; i++) {
27: ic = 2*c[i];
28: t[ii] = b[ic];
29: t[ii+1] = b[ic+1];
30: ii += 2;
31: }
33: /* forward solve the U^T */
34: idx = 0;
35: for (i=0; i<n; i++) {
37: v = aa + 4*diag[i];
38: /* multiply by the inverse of the block diagonal */
39: x1 = t[idx]; x2 = t[1+idx];
40: s1 = v[0]*x1 + v[1]*x2;
41: s2 = v[2]*x1 + v[3]*x2;
42: v += 4;
44: vi = aj + diag[i] + 1;
45: nz = ai[i+1] - diag[i] - 1;
46: while (nz--) {
47: oidx = 2*(*vi++);
48: t[oidx] -= v[0]*s1 + v[1]*s2;
49: t[oidx+1] -= v[2]*s1 + v[3]*s2;
50: v += 4;
51: }
52: t[idx] = s1;t[1+idx] = s2;
53: idx += 2;
54: }
55: /* backward solve the L^T */
56: for (i=n-1; i>=0; i--) {
57: v = aa + 4*diag[i] - 4;
58: vi = aj + diag[i] - 1;
59: nz = diag[i] - ai[i];
60: idt = 2*i;
61: s1 = t[idt]; s2 = t[1+idt];
62: while (nz--) {
63: idx = 2*(*vi--);
64: t[idx] -= v[0]*s1 + v[1]*s2;
65: t[idx+1] -= v[2]*s1 + v[3]*s2;
66: v -= 4;
67: }
68: }
70: /* copy t into x according to permutation */
71: ii = 0;
72: for (i=0; i<n; i++) {
73: ir = 2*r[i];
74: x[ir] = t[ii];
75: x[ir+1] = t[ii+1];
76: ii += 2;
77: }
79: ISRestoreIndices(isrow,&rout);
80: ISRestoreIndices(iscol,&cout);
81: VecRestoreArrayRead(bb,&b);
82: VecRestoreArray(xx,&x);
83: PetscLogFlops(2.0*4*(a->nz) - 2.0*A->cmap->n);
84: return(0);
85: }
87: PetscErrorCode MatSolveTranspose_SeqBAIJ_2(Mat A,Vec bb,Vec xx)
88: {
89: Mat_SeqBAIJ *a=(Mat_SeqBAIJ*)A->data;
90: PetscErrorCode ierr;
91: IS iscol=a->col,isrow=a->row;
92: const PetscInt n =a->mbs,*vi,*ai=a->i,*aj=a->j,*diag=a->diag;
93: const PetscInt *r,*c,*rout,*cout;
94: PetscInt nz,idx,idt,j,i,oidx,ii,ic,ir;
95: const PetscInt bs =A->rmap->bs,bs2=a->bs2;
96: const MatScalar *aa=a->a,*v;
97: PetscScalar s1,s2,x1,x2,*x,*t;
98: const PetscScalar *b;
101: VecGetArrayRead(bb,&b);
102: VecGetArray(xx,&x);
103: t = a->solve_work;
105: ISGetIndices(isrow,&rout); r = rout;
106: ISGetIndices(iscol,&cout); c = cout;
108: /* copy b into temp work space according to permutation */
109: for (i=0; i<n; i++) {
110: ii = bs*i; ic = bs*c[i];
111: t[ii] = b[ic]; t[ii+1] = b[ic+1];
112: }
114: /* forward solve the U^T */
115: idx = 0;
116: for (i=0; i<n; i++) {
117: v = aa + bs2*diag[i];
118: /* multiply by the inverse of the block diagonal */
119: x1 = t[idx]; x2 = t[1+idx];
120: s1 = v[0]*x1 + v[1]*x2;
121: s2 = v[2]*x1 + v[3]*x2;
122: v -= bs2;
124: vi = aj + diag[i] - 1;
125: nz = diag[i] - diag[i+1] - 1;
126: for (j=0; j>-nz; j--) {
127: oidx = bs*vi[j];
128: t[oidx] -= v[0]*s1 + v[1]*s2;
129: t[oidx+1] -= v[2]*s1 + v[3]*s2;
130: v -= bs2;
131: }
132: t[idx] = s1;t[1+idx] = s2;
133: idx += bs;
134: }
135: /* backward solve the L^T */
136: for (i=n-1; i>=0; i--) {
137: v = aa + bs2*ai[i];
138: vi = aj + ai[i];
139: nz = ai[i+1] - ai[i];
140: idt = bs*i;
141: s1 = t[idt]; s2 = t[1+idt];
142: for (j=0; j<nz; j++) {
143: idx = bs*vi[j];
144: t[idx] -= v[0]*s1 + v[1]*s2;
145: t[idx+1] -= v[2]*s1 + v[3]*s2;
146: v += bs2;
147: }
148: }
150: /* copy t into x according to permutation */
151: for (i=0; i<n; i++) {
152: ii = bs*i; ir = bs*r[i];
153: x[ir] = t[ii]; x[ir+1] = t[ii+1];
154: }
156: ISRestoreIndices(isrow,&rout);
157: ISRestoreIndices(iscol,&cout);
158: VecRestoreArrayRead(bb,&b);
159: VecRestoreArray(xx,&x);
160: PetscLogFlops(2.0*bs2*(a->nz) - bs*A->cmap->n);
161: return(0);
162: }