1 | /* mex ndsumC.c flops.obj |
---|
2 | */ |
---|
3 | /* Written by Tom Minka |
---|
4 | * (c) Microsoft Corporation. All rights reserved. |
---|
5 | */ |
---|
#include <string.h>
#include "mex.h"
#include "flops.h"
---|
8 | |
---|
/* Accumulate a multidimensional array into a smaller one by summing out
 * selected dimensions.
 *
 * dest  is the destination array.  Values are added onto whatever it already
 *       holds, so the caller must zero it first.  It is laid out like src
 *       but with every summed-out dimension collapsed to length 1.
 * src   is the source array, visited in storage order (first subscript
 *       varies fastest).
 * ndim  is the number of dimensions in the source array (the length of size
 *       and mask).
 * size[i] is the size of dimension i in the source array.
 * total is the number of elements of src to visit.
 * mask[i] nonzero indicates that dimension i is to be summed out.
 */
void ndsum(double *dest, double *src, unsigned ndim, unsigned *size,
           unsigned total, unsigned *mask)
{
  unsigned *subs, *advance, *rewind_by;
  unsigned stride, i, j;

  if(ndim == 0) {
    /* There are no subscripts to track, so every element lands in dest[0].
     * Handling this up front also avoids evaluating ndim-1 in unsigned
     * arithmetic and writing into zero-length work arrays.
     */
    for(j=0;j<total;j++) {
      *dest += src[j];
    }
    return;
  }

  /* subs (initialized to zero) is the current position in the source array.
   * stride is the cumulative product of the masked sizes below dimension i,
   * where the masked size of a summed-out dimension is 1.
   *
   * The position in the destination array is
   *   sum_i subs[i]*stride_i*!mask[i].
   * To compute this incrementally, we add advance[i] = stride_i*!mask[i]
   * whenever we increment subs[i], and we subtract
   * rewind_by[i] = stride_i*(masked_size_i-1)
   * whenever we set subs[i] back to 0 (from size[i]-1).
   */
  subs = mxCalloc(ndim,sizeof(*subs));
  advance = mxCalloc(ndim,sizeof(*advance));
  rewind_by = mxCalloc(ndim,sizeof(*rewind_by));

  stride = 1;
  for(i=0;i<ndim;i++) {
    unsigned masked_size = mask[i] ? 1 : size[i];
    advance[i] = mask[i] ? 0 : stride;
    rewind_by[i] = stride*(masked_size-1);
    stride *= masked_size;
  }

  for(j=0;j<total;j++) {
    *dest += *src++;
    /* increment subs (odometer style) and move dest to match */
    for(i=0;i<ndim;i++) {
      if(subs[i] == size[i]-1) {
        /* this dimension wraps around; carry into the next one */
        subs[i] = 0;
        dest -= rewind_by[i];
      }
      else {
        subs[i]++;
        dest += advance[i];
        break;
      }
    }
  }

  mxFree(rewind_by);
  mxFree(advance);
  mxFree(subs);
}
---|
68 | |
---|
69 | void mexFunction(int nlhs, mxArray *plhs[], |
---|
70 | int nrhs, const mxArray *prhs[]) |
---|
71 | { |
---|
72 | /* prhs[0] is the multidimensional array. |
---|
73 | * prhs[1] is a vector listing the dimensions to sum out. |
---|
74 | */ |
---|
75 | int i,j; |
---|
76 | int ndim = mxGetNumberOfDimensions(prhs[0]); |
---|
77 | const int *size = mxGetDimensions(prhs[0]); |
---|
78 | double *src = mxGetPr(prhs[0]), *dest; |
---|
79 | double *dim = mxGetPr(prhs[1]); |
---|
80 | int nremove = mxGetNumberOfElements(prhs[1]); |
---|
81 | int dest_ndim = ndim - nremove; |
---|
82 | unsigned *dest_size; |
---|
83 | unsigned *mask; |
---|
84 | int total = mxGetNumberOfElements(prhs[0]), dest_total; |
---|
85 | |
---|
86 | if(dest_ndim == 0) { |
---|
87 | /* sum out everything */ |
---|
88 | plhs[0] = mxCreateDoubleMatrix(1,1,mxREAL); |
---|
89 | dest = mxGetPr(plhs[0]); |
---|
90 | for(i=0;i<total;i++) { |
---|
91 | *dest += src[i]; |
---|
92 | } |
---|
93 | addflops(total-1); |
---|
94 | return; |
---|
95 | } |
---|
96 | |
---|
97 | /* convert dim to a mask. */ |
---|
98 | /* mask[i] = 1 if dimension i is to be removed. */ |
---|
99 | mask = mxCalloc(ndim,sizeof(int)); |
---|
100 | /* dimensions in dim are indexed starting from 1. */ |
---|
101 | for(i=0;i<nremove;i++) mask[(int)dim[i]-1] = 1; |
---|
102 | /* compute the new size vector. */ |
---|
103 | dest_size = mxCalloc(dest_ndim,sizeof(int)); |
---|
104 | j = 0; |
---|
105 | for(i=0;i<ndim;i++) { |
---|
106 | if(!mask[i]) dest_size[j++] = size[i]; |
---|
107 | } |
---|
108 | plhs[0] = mxCreateNumericArray(dest_ndim, dest_size, mxDOUBLE_CLASS, mxREAL); |
---|
109 | dest_total = mxGetNumberOfElements(plhs[0]); |
---|
110 | dest = mxGetPr(plhs[0]); |
---|
111 | addflops(total-dest_total); |
---|
112 | if(dest_total == total) { |
---|
113 | /* don't sum anything */ |
---|
114 | memcpy(dest,src,total*sizeof(double)); |
---|
115 | return; |
---|
116 | } |
---|
117 | ndsum(dest,src,ndim,size,total,mask); |
---|
118 | } |
---|