Actual source code: submat.c

#define PETSCMAT_DLL

#include "private/matimpl.h"

  5: typedef struct {
  6:   IS isrow,iscol;               /* rows and columns in submatrix, only used to check consistency */
  7:   Vec left,right;               /* optional scaling */
  8:   Vec olwork,orwork;            /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
  9:   Vec lwork,rwork;              /* work vectors inside the scatters */
 10:   VecScatter lrestrict,rprolong;
 11:   Mat A;
 12:   PetscScalar scale;
 13: } Mat_SubMatrix;

 17: static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
 18: {
 19:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

 23:   if (!Na->left) {
 24:     *xx = x;
 25:   } else {
 26:     if (!Na->olwork) {
 27:       VecDuplicate(Na->left,&Na->olwork);
 28:     }
 29:     VecPointwiseMult(Na->left,x,Na->olwork);
 30:     *xx = Na->olwork;
 31:   }
 32:   return(0);
 33: }

 37: static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
 38: {
 39:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

 43:   if (!Na->right) {
 44:     *xx = x;
 45:   } else {
 46:     if (!Na->orwork) {
 47:       VecDuplicate(Na->right,&Na->orwork);
 48:     }
 49:     VecPointwiseMult(Na->right,x,Na->orwork);
 50:     *xx = Na->orwork;
 51:   }
 52:   return(0);
 53: }

 57: static PetscErrorCode PostScaleLeft(Mat N,Vec x)
 58: {
 59:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

 63:   if (Na->left) {
 64:     VecPointwiseMult(x,x,Na->left);
 65:   }
 66:   return(0);
 67: }

 71: static PetscErrorCode PostScaleRight(Mat N,Vec x)
 72: {
 73:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

 77:   if (Na->right) {
 78:     VecPointwiseMult(x,x,Na->right);
 79:   }
 80:   return(0);
 81: }

 85: static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
 86: {
 87:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

 90:   Na->scale *= scale;
 91:   return(0);
 92: }

 96: static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
 97: {
 98:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

102:   if (left) {
103:     if (!Na->left) {
104:       VecDuplicate(left,&Na->left);
105:       VecCopy(left,Na->left);
106:     } else {
107:       VecPointwiseMult(Na->left,left,Na->left);
108:     }
109:   }
110:   if (right) {
111:     if (!Na->right) {
112:       VecDuplicate(right,&Na->right);
113:       VecCopy(right,Na->right);
114:     } else {
115:       VecPointwiseMult(Na->right,right,Na->right);
116:     }
117:   }
118:   return(0);
119: }

123: static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
124: {
125:   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
126:   Vec             xx=0;
127:   PetscErrorCode  ierr;

130:   PreScaleRight(N,x,&xx);
131:   VecZeroEntries(Na->rwork);
132:   VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
133:   VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
134:   MatMult(Na->A,Na->rwork,Na->lwork);
135:   VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);
136:   VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);
137:   PostScaleLeft(N,y);
138:   VecScale(y,Na->scale);
139:   return(0);
140: }

144: static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
145: {
146:   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
147:   Vec             xx=0;
148:   PetscErrorCode  ierr;

151:   PreScaleRight(N,v1,&xx);
152:   VecZeroEntries(Na->rwork);
153:   VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
154:   VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
155:   MatMult(Na->A,Na->rwork,Na->lwork);
156:   VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);
157:   VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);
158:   PostScaleLeft(N,v3);
159:   VecAYPX(v3,Na->scale,v2);
160:   return(0);
161: }

165: static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
166: {
167:   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
168:   Vec             xx=0;

172:   PreScaleLeft(N,x,&xx);
173:   VecZeroEntries(Na->lwork);
174:   VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
175:   VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
176:   MatMultTranspose(Na->A,Na->lwork,Na->rwork);
177:   VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_FORWARD);
178:   VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_FORWARD);
179:   PostScaleRight(N,y);
180:   VecScale(y,Na->scale);
181:   return(0);
182: }

186: static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
187: {
188:   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
189:   Vec             xx =0;

193:   PreScaleLeft(N,v1,&xx);
194:   VecZeroEntries(Na->lwork);
195:   VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
196:   VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
197:   MatMultTranspose(Na->A,Na->lwork,Na->rwork);
198:   VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_FORWARD);
199:   VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_FORWARD);
200:   PostScaleRight(N,v3);
201:   VecAYPX(v3,Na->scale,v2);
202:   return(0);
203: }

207: static PetscErrorCode MatDestroy_SubMatrix(Mat N)
208: {
209:   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;

213:   ISDestroy(Na->isrow);
214:   ISDestroy(Na->iscol);
215:   if (Na->left) {VecDestroy(Na->left);}
216:   if (Na->right) {VecDestroy(Na->right);}
217:   if (Na->olwork) {VecDestroy(Na->olwork);}
218:   if (Na->orwork) {VecDestroy(Na->orwork);}
219:   VecDestroy(Na->lwork);
220:   VecDestroy(Na->rwork);
221:   VecScatterDestroy(Na->lrestrict);
222:   VecScatterDestroy(Na->rprolong);
223:   MatDestroy(Na->A);
224:   PetscFree(Na);
225:   return(0);
226: }

230: /*@
231:    MatCreateSubMatrix - Creates a composite matrix that acts as a submatrix

233:    Collective on Mat

235:    Input Parameters:
236: +  A - matrix that we will extract a submatrix of
237: .  isrow - rows to be present in the submatrix
238: -  iscol - columns to be present in the submatrix

240:    Output Parameters:
241: .  newmat - new matrix

243:    Level: developer

245:    Notes:
246:    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.

248: .seealso: MatGetSubMatrix(), MatSubMatrixUpdate()
249: @*/
250: PetscErrorCode  MatCreateSubMatrix(Mat A,IS isrow,IS iscol,Mat *newmat)
251: {
252:   Vec            left,right;
253:   PetscInt       m,n;
254:   Mat            N;
255:   Mat_SubMatrix *Na;

263:   *newmat = 0;

265:   MatCreate(((PetscObject)A)->comm,&N);
266:   ISGetLocalSize(isrow,&m);
267:   ISGetLocalSize(iscol,&n);
268:   MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);
269:   PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);

271:   PetscNewLog(N,Mat_SubMatrix,&Na);
272:   N->data   = (void*)Na;
273:   PetscObjectReference((PetscObject)A);
274:   PetscObjectReference((PetscObject)isrow);
275:   PetscObjectReference((PetscObject)iscol);
276:   Na->A     = A;
277:   Na->isrow = isrow;
278:   Na->iscol = iscol;
279:   Na->scale = 1.0;

281:   N->ops->destroy          = MatDestroy_SubMatrix;
282:   N->ops->mult             = MatMult_SubMatrix;
283:   N->ops->multadd          = MatMultAdd_SubMatrix;
284:   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
285:   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
286:   N->ops->scale            = MatScale_SubMatrix;
287:   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;

289:   N->assembled = PETSC_TRUE;

291:   PetscLayoutSetBlockSize(N->rmap,A->rmap->bs);
292:   PetscLayoutSetBlockSize(N->cmap,A->cmap->bs);
293:   PetscLayoutSetUp(N->rmap);
294:   PetscLayoutSetUp(N->cmap);

296:   MatGetVecs(A,&Na->rwork,&Na->lwork);
297:   VecCreate(((PetscObject)isrow)->comm,&left);
298:   VecCreate(((PetscObject)iscol)->comm,&right);
299:   VecSetSizes(left,m,PETSC_DETERMINE);
300:   VecSetSizes(right,n,PETSC_DETERMINE);
301:   VecSetUp(left);
302:   VecSetUp(right);
303:   VecScatterCreate(Na->lwork,isrow,left,PETSC_NULL,&Na->lrestrict);
304:   VecScatterCreate(right,PETSC_NULL,Na->rwork,iscol,&Na->rprolong);
305:   VecDestroy(left);
306:   VecDestroy(right);

308:   *newmat = N;
309:   return(0);
310: }


315: /*@
316:    MatSubMatrixUpdate - Updates a submatrix

318:    Collective on Mat

320:    Input Parameters:
321: +  N - submatrix to update
322: .  A - full matrix in the submatrix
323: .  isrow - rows in the update (same as the first time the submatrix was created)
324: -  iscol - columns in the update (same as the first time the submatrix was created)

326:    Level: developer

328:    Notes:
329:    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.

331: .seealso: MatGetSubMatrix(), MatCreateSubMatrix()
332: @*/
333: PetscErrorCode  MatSubMatrixUpdate(Mat N,Mat A,IS isrow,IS iscol)
334: {
335:   PetscErrorCode  ierr;
336:   PetscTruth      flg;
337:   Mat_SubMatrix  *Na;

344:   PetscTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);
345:   if (!flg) SETERRQ(PETSC_ERR_ARG_WRONG,"Matrix has wrong type");

347:   Na = (Mat_SubMatrix*)N->data;
348:   ISEqual(isrow,Na->isrow,&flg);
349:   if (!flg) SETERRQ(PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
350:   ISEqual(iscol,Na->iscol,&flg);
351:   if (!flg) SETERRQ(PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");

353:   PetscObjectReference((PetscObject)A);
354:   MatDestroy(Na->A);
355:   Na->A = A;

357:   Na->scale = 1.0;
358:   if (Na->left) {VecDestroy(Na->left);}
359:   if (Na->right) {VecDestroy(Na->right);}
360:   return(0);
361: }