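/* solve_tril: MEX function x = solve_tril(T,b), which solves T*x = b for a
 * lower-triangular matrix T using BLAS dtrsm.
 * See solve_triu for compilation instructions.
 */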
---|
#include "mex.h"
#include <limits.h>
#include <string.h>
5 | |
---|
6 | extern int dtrsm(char *side, char *uplo, char *transa, char *diag, |
---|
7 | int *m, int *n, double *alpha, double *a, int *lda, |
---|
8 | double *b, int *ldb); |
---|
9 | |
---|
10 | void mexFunction(int nlhs, mxArray *plhs[], |
---|
11 | int nrhs, const mxArray *prhs[]) |
---|
12 | { |
---|
13 | int m,n; |
---|
14 | double *T,*b,*x; |
---|
15 | char side='L',uplo='L',trans='N',diag='N'; |
---|
16 | double one = 1; |
---|
17 | |
---|
18 | if(nrhs != 2 || nlhs > 1) |
---|
19 | mexErrMsgTxt("Usage: x = solve_tril(T,b)"); |
---|
20 | |
---|
21 | /* prhs[0] is first argument. |
---|
22 | * mxGetPr returns double* (data, col-major) |
---|
23 | * mxGetM returns int (rows) |
---|
24 | * mxGetN returns int (cols) |
---|
25 | */ |
---|
26 | /* m = rows(T) */ |
---|
27 | m = mxGetM(prhs[0]); |
---|
28 | n = mxGetN(prhs[0]); |
---|
29 | if(m != n) mexErrMsgTxt("matrix must be square"); |
---|
30 | /* n = cols(b) */ |
---|
31 | n = mxGetN(prhs[1]); |
---|
32 | T = mxGetPr(prhs[0]); |
---|
33 | b = mxGetPr(prhs[1]); |
---|
34 | |
---|
35 | if(mxIsSparse(prhs[0]) || mxIsSparse(prhs[1])) { |
---|
36 | mexErrMsgTxt("Sorry, can't handle sparse matrices yet."); |
---|
37 | } |
---|
38 | if(mxGetNumberOfDimensions(prhs[0]) != 2) { |
---|
39 | mexErrMsgTxt("Arguments must be matrices."); |
---|
40 | } |
---|
41 | if(mxGetNumberOfDimensions(prhs[1]) != 2) { |
---|
42 | mexErrMsgTxt("Arguments must be matrices."); |
---|
43 | } |
---|
44 | |
---|
45 | /* plhs[0] is first output */ |
---|
46 | /* x is same size as b */ |
---|
47 | plhs[0] = mxCreateDoubleMatrix(m, n, mxREAL); |
---|
48 | x = mxGetPr(plhs[0]); |
---|
49 | /* copy b into x to speed up memory access */ |
---|
50 | memcpy(x,b,m*n*sizeof(double)); |
---|
51 | b = x; |
---|
52 | |
---|
53 | dtrsm(&side,&uplo,&trans,&diag,&m,&n,&one,T,&m,x,&m); |
---|
54 | } |
---|
55 | |
---|
#if 0
/* Reference (non-BLAS) triangular solvers, kept for documentation only.
 * Each one solves T*x = b, where T is m-by-m and b and x are m-by-n, all
 * stored column-major. Each divides by the diagonal entries of T, which
 * are assumed nonzero.
 * x may be the same buffer as b: every x entry is written only after the
 * matching b entry has been read, and only already-solved x entries are
 * read back.
 * These were previously bare statements at file scope that used undeclared
 * i, j, k. They are wrapped in functions so this section is valid C if it
 * is ever enabled. */

/* Upper-triangular T: back substitution, last row first. */
static void solve_triu_reference(int m, int n, const double *T,
                                 const double *b, double *x)
{
    int i, j, k;

    /* Guard: the first loop below would otherwise read T[-1] when m == 0. */
    if (m <= 0)
        return;
    for (j = 0; j < n; j++)
        x[m-1 + m*j] = b[m-1 + m*j] / T[m*m - 1];
    for (i = m-2; i >= 0; i--) {
        for (j = 0; j < n; j++) {
            double s = 0;
            for (k = i+1; k < m; k++)
                s += T[i + m*k] * x[k + m*j];
            x[i + m*j] = (b[i + m*j] - s) / T[i + m*i];
        }
    }
}

/* Lower-triangular T: forward substitution, first row first. */
static void solve_tril_reference(int m, int n, const double *T,
                                 const double *b, double *x)
{
    int i, j, k;

    /* Guard: the first loop below would otherwise read T[0] of an empty T. */
    if (m <= 0)
        return;
    for (j = 0; j < n; j++)
        x[m*j] = b[m*j] / T[0];
    for (i = 1; i < m; i++) {
        for (j = 0; j < n; j++) {
            double s = 0;
            for (k = 0; k < i; k++)
                s += T[i + m*k] * x[k + m*j];
            x[i + m*j] = (b[i + m*j] - s) / T[i + m*i];
        }
    }
}
#endif
---|