[37] | 1 | /* How to write MEX files that call BLAS/LAPACK routines (this file calls the BLAS routine dtrsm): |
---|
| 2 | * http://www.mathworks.com/access/helpdesk/help/techdoc/matlab_external/ch04cr17.html |
---|
| 3 | */ |
---|
#include "mex.h"
#include <stddef.h>
#include <string.h>
---|
| 6 | |
---|
| 7 | extern int dtrsm(char *side, char *uplo, char *transa, char *diag, |
---|
| 8 | int *m, int *n, double *alpha, double *a, int *lda, |
---|
| 9 | double *b, int *ldb); |
---|
| 10 | |
---|
| 11 | void mexFunction(int nlhs, mxArray *plhs[], |
---|
| 12 | int nrhs, const mxArray *prhs[]) |
---|
| 13 | { |
---|
| 14 | int m,n; |
---|
| 15 | double *T,*b,*x; |
---|
| 16 | char side='L',uplo='U',trans='N',diag='N'; |
---|
| 17 | double one = 1; |
---|
| 18 | |
---|
| 19 | if(nrhs != 2 || nlhs > 1) |
---|
| 20 | mexErrMsgTxt("Usage: x = solve_triu(T,b)"); |
---|
| 21 | |
---|
| 22 | /* prhs[0] is first argument. |
---|
| 23 | * mxGetPr returns double* (data, col-major) |
---|
| 24 | * mxGetM returns int (rows) |
---|
| 25 | * mxGetN returns int (cols) |
---|
| 26 | */ |
---|
| 27 | /* m = rows(T) */ |
---|
| 28 | m = mxGetM(prhs[0]); |
---|
| 29 | n = mxGetN(prhs[0]); |
---|
| 30 | if(m != n) mexErrMsgTxt("matrix must be square"); |
---|
| 31 | /* n = cols(b) */ |
---|
| 32 | n = mxGetN(prhs[1]); |
---|
| 33 | T = mxGetPr(prhs[0]); |
---|
| 34 | b = mxGetPr(prhs[1]); |
---|
| 35 | |
---|
| 36 | if(mxIsSparse(prhs[0]) || mxIsSparse(prhs[1])) { |
---|
| 37 | mexErrMsgTxt("Sorry, can't handle sparse matrices yet."); |
---|
| 38 | } |
---|
| 39 | if(mxGetNumberOfDimensions(prhs[0]) != 2) { |
---|
| 40 | mexErrMsgTxt("Arguments must be matrices."); |
---|
| 41 | } |
---|
| 42 | if(mxGetNumberOfDimensions(prhs[1]) != 2) { |
---|
| 43 | mexErrMsgTxt("Arguments must be matrices."); |
---|
| 44 | } |
---|
| 45 | |
---|
| 46 | /* plhs[0] is first output */ |
---|
| 47 | /* x is same size as b */ |
---|
| 48 | plhs[0] = mxCreateDoubleMatrix(m, n, mxREAL); |
---|
| 49 | x = mxGetPr(plhs[0]); |
---|
| 50 | /* copy b into x to speed up memory access */ |
---|
| 51 | memcpy(x,b,m*n*sizeof(double)); |
---|
| 52 | b = x; |
---|
| 53 | |
---|
| 54 | dtrsm(&side,&uplo,&trans,&diag,&m,&n,&one,T,&m,x,&m); |
---|
| 55 | } |
---|
| 56 | |
---|
| 57 | #if 0 |
---|
/* Disabled plain-C reference code, kept for documentation only.  The first
 * part is back substitution for upper-triangular T; the second is forward
 * substitution for lower-triangular T.  Both work on a column-major m-by-m T
 * (element (i,k) at T[i + m*k]) and an m-by-n b, and write the solution to x.
 * Neither checks for a zero diagonal entry before dividing.  This is a
 * fragment, not a function: i, j, k, m, n, T, b and x are not declared here,
 * so removing the #if 0 without wrapping it in a function will not compile. */
| 58 | /* Upper triangular */ |
---|
| 59 | for(j=0;j<n;j++) x[m-1 + m*j] = b[m-1 + m*j]/T[m*m - 1]; |
---|
| 60 | for(i=m-2;i>=0;i--) { |
---|
| 61 | for(j=0;j<n;j++) { |
---|
| 62 | double s = 0; |
---|
| 63 | for(k=i+1;k<m;k++) { |
---|
| 64 | s += T[i + m*k]*x[k + m*j]; |
---|
| 65 | } |
---|
| 66 | x[i + m*j] = (b[i + m*j] - s)/T[i + m*i]; |
---|
| 67 | } |
---|
| 68 | } |
---|
| 69 | /* Lower triangular */ |
---|
| 70 | for(j=0;j<n;j++) x[m*j] = b[m*j]/T[0]; |
---|
| 71 | for(i=1;i<m;i++) { |
---|
| 72 | for(j=0;j<n;j++) { |
---|
| 73 | double s = 0; |
---|
| 74 | for(k=0;k<i;k++) { |
---|
| 75 | s += T[i + m*k]*x[k + m*j]; |
---|
| 76 | } |
---|
| 77 | x[i + m*j] = (b[i + m*j] - s)/T[i + m*i]; |
---|
| 78 | } |
---|
| 79 | } |
---|
| 80 | #endif |
---|