/* See solve_triu for compilation instructions. */
---|
| 3 | #include "mex.h" |
---|
| 4 | #include <string.h> |
---|
| 5 | |
---|
/* Prototype for the BLAS triangular-solve routine dtrsm, called by
 * mexFunction below.  Every argument is passed by pointer (Fortran calling
 * convention).  mexFunction passes its right-hand side in `b` and then
 * returns that same buffer, i.e. it relies on dtrsm writing the solution
 * into `b` in place.
 *
 * NOTE(review): the linked symbol name and integer width depend on the BLAS
 * library/platform (many Unix BLAS builds export `dtrsm_`, and 64-bit MATLAB
 * BLAS uses ptrdiff_t-sized integers rather than int); the reference BLAS
 * routine is a subroutine with no return value, so the `int` return type
 * here is presumably never used -- confirm against the build setup. */
extern int dtrsm(char *side, char *uplo, char *transa, char *diag,
                 int *m, int *n, double *alpha, double *a, int *lda,
                 double *b, int *ldb);
---|
| 9 | |
---|
| 10 | void mexFunction(int nlhs, mxArray *plhs[], |
---|
| 11 | int nrhs, const mxArray *prhs[]) |
---|
| 12 | { |
---|
| 13 | int m,n; |
---|
| 14 | double *T,*b,*x; |
---|
| 15 | char side='L',uplo='L',trans='N',diag='N'; |
---|
| 16 | double one = 1; |
---|
| 17 | |
---|
| 18 | if(nrhs != 2 || nlhs > 1) |
---|
| 19 | mexErrMsgTxt("Usage: x = solve_tril(T,b)"); |
---|
| 20 | |
---|
| 21 | /* prhs[0] is first argument. |
---|
| 22 | * mxGetPr returns double* (data, col-major) |
---|
| 23 | * mxGetM returns int (rows) |
---|
| 24 | * mxGetN returns int (cols) |
---|
| 25 | */ |
---|
| 26 | /* m = rows(T) */ |
---|
| 27 | m = mxGetM(prhs[0]); |
---|
| 28 | n = mxGetN(prhs[0]); |
---|
| 29 | if(m != n) mexErrMsgTxt("matrix must be square"); |
---|
| 30 | /* n = cols(b) */ |
---|
| 31 | n = mxGetN(prhs[1]); |
---|
| 32 | T = mxGetPr(prhs[0]); |
---|
| 33 | b = mxGetPr(prhs[1]); |
---|
| 34 | |
---|
| 35 | if(mxIsSparse(prhs[0]) || mxIsSparse(prhs[1])) { |
---|
| 36 | mexErrMsgTxt("Sorry, can't handle sparse matrices yet."); |
---|
| 37 | } |
---|
| 38 | if(mxGetNumberOfDimensions(prhs[0]) != 2) { |
---|
| 39 | mexErrMsgTxt("Arguments must be matrices."); |
---|
| 40 | } |
---|
| 41 | if(mxGetNumberOfDimensions(prhs[1]) != 2) { |
---|
| 42 | mexErrMsgTxt("Arguments must be matrices."); |
---|
| 43 | } |
---|
| 44 | |
---|
| 45 | /* plhs[0] is first output */ |
---|
| 46 | /* x is same size as b */ |
---|
| 47 | plhs[0] = mxCreateDoubleMatrix(m, n, mxREAL); |
---|
| 48 | x = mxGetPr(plhs[0]); |
---|
| 49 | /* copy b into x to speed up memory access */ |
---|
| 50 | memcpy(x,b,m*n*sizeof(double)); |
---|
| 51 | b = x; |
---|
| 52 | |
---|
| 53 | dtrsm(&side,&uplo,&trans,&diag,&m,&n,&one,T,&m,x,&m); |
---|
| 54 | } |
---|
| 55 | |
---|
#if 0
/* Disabled reference code: plain-C triangular solves for T*x = b over
 * column-major arrays, where element (r,c) of an m-row matrix is stored
 * at [r + m*c].  The active code path in mexFunction calls dtrsm instead.
 * This is not a complete function -- i, j and k are never declared and the
 * statements sit at file scope -- so it will not compile if re-enabled
 * as-is. */
/* Upper triangular */
/* Back substitution: solve the last row first, then move upward. */
for(j=0;j<n;j++) x[m-1 + m*j] = b[m-1 + m*j]/T[m*m - 1];
for(i=m-2;i>=0;i--) {
    for(j=0;j<n;j++) {
        double s = 0;
        /* s = sum over k > i of T(i,k) * x(k,j), using already-solved rows */
        for(k=i+1;k<m;k++) {
            s += T[i + m*k]*x[k + m*j];
        }
        x[i + m*j] = (b[i + m*j] - s)/T[i + m*i];
    }
}
/* Lower triangular */
/* Forward substitution: solve the first row first, then move downward. */
for(j=0;j<n;j++) x[m*j] = b[m*j]/T[0];
for(i=1;i<m;i++) {
    for(j=0;j<n;j++) {
        double s = 0;
        /* s = sum over k < i of T(i,k) * x(k,j), using already-solved rows */
        for(k=0;k<i;k++) {
            s += T[i + m*k]*x[k + m*j];
        }
        x[i + m*j] = (b[i + m*j] - s)/T[i + m*i];
    }
}
#endif
---|