1 | /*************************************************************************/ |
---|
2 | /* */ |
---|
3 | /* Prune a decision tree and predict its error rate */ |
---|
4 | /* ------------------------------------------------ */ |
---|
5 | /* */ |
---|
6 | /*************************************************************************/ |
---|
7 | |
---|
8 | |
---|
9 | #include "defns.i" |
---|
10 | #include "types.i" |
---|
11 | #include "extern.i" |
---|
12 | |
---|
13 | |
---|
14 | extern ItemCount *Weight; |
---|
15 | |
---|
16 | Set *PossibleValues=Nil; |
---|
17 | Boolean Changed; |
---|
18 | |
---|
19 | #define LocalVerbosity(x) if (Sh >= 0 && VERBOSITY >= x) |
---|
20 | #define Intab(x) Indent(x, "| ") |
---|
21 | |
---|
22 | |
---|
23 | |
---|
24 | /*************************************************************************/ |
---|
25 | /* */ |
---|
26 | /* Prune tree T, returning true if tree has been modified */ |
---|
27 | /* */ |
---|
28 | /*************************************************************************/ |
---|
29 | |
---|
30 | |
---|
31 | Boolean Prune(T) |
---|
32 | /* ----- */ |
---|
33 | Tree T; |
---|
34 | { |
---|
35 | ItemNo i; |
---|
36 | float EstimateErrors(); |
---|
37 | Attribute a; |
---|
38 | |
---|
39 | InitialiseWeights(); |
---|
40 | AllKnown = true; |
---|
41 | |
---|
42 | Verbosity(1) printf("\n"); |
---|
43 | |
---|
44 | Changed = false; |
---|
45 | |
---|
46 | EstimateErrors(T, 0, MaxItem, 0, true); |
---|
47 | |
---|
48 | if ( SUBSET ) |
---|
49 | { |
---|
50 | if ( ! PossibleValues ) |
---|
51 | { |
---|
52 | PossibleValues = (Set *) calloc(MaxAtt+1, sizeof(Set)); |
---|
53 | } |
---|
54 | |
---|
55 | ForEach(a, 0, MaxAtt) |
---|
56 | { |
---|
57 | if ( MaxAttVal[a] ) |
---|
58 | { |
---|
59 | PossibleValues[a] = (Set) malloc((MaxAttVal[a]>>3) + 1); |
---|
60 | ClearBits((MaxAttVal[a]>>3) + 1, PossibleValues[a]); |
---|
61 | ForEach(i, 1, MaxAttVal[a]) |
---|
62 | { |
---|
63 | SetBit(i, PossibleValues[a]); |
---|
64 | } |
---|
65 | } |
---|
66 | } |
---|
67 | |
---|
68 | CheckPossibleValues(T); |
---|
69 | } |
---|
70 | |
---|
71 | return Changed; |
---|
72 | } |
---|
73 | |
---|
74 | |
---|
75 | |
---|
76 | |
---|
77 | /*************************************************************************/ |
---|
78 | /* */ |
---|
79 | /* Estimate the errors in a given subtree */ |
---|
80 | /* */ |
---|
81 | /*************************************************************************/ |
---|
82 | |
---|
83 | |
---|
84 | float EstimateErrors(T, Fp, Lp, Sh, UpdateTree) |
---|
85 | /* -------------- */ |
---|
86 | Tree T; |
---|
87 | ItemNo Fp, Lp; |
---|
88 | short Sh; |
---|
89 | Boolean UpdateTree; |
---|
90 | { |
---|
91 | ItemNo i, Kp, Ep, Group(); |
---|
92 | ItemCount Cases, KnownCases, *LocalClassDist, TreeErrors, LeafErrors, |
---|
93 | ExtraLeafErrors, BranchErrors, CountItems(), Factor, MaxFactor, AddErrs(); |
---|
94 | DiscrValue v, MaxBr; |
---|
95 | ClassNo c, BestClass; |
---|
96 | Boolean PrevAllKnown; |
---|
97 | |
---|
98 | /* Generate the class frequency distribution */ |
---|
99 | |
---|
100 | Cases = CountItems(Fp, Lp); |
---|
101 | LocalClassDist = (ItemCount *) calloc(MaxClass+1, sizeof(ItemCount)); |
---|
102 | |
---|
103 | ForEach(i, Fp, Lp) |
---|
104 | { |
---|
105 | LocalClassDist[ Class(Item[i]) ] += Weight[i]; |
---|
106 | } |
---|
107 | |
---|
108 | /* Find the most frequent class and update the tree */ |
---|
109 | |
---|
110 | BestClass = T->Leaf; |
---|
111 | ForEach(c, 0, MaxClass) |
---|
112 | { |
---|
113 | if ( LocalClassDist[c] > LocalClassDist[BestClass] ) |
---|
114 | { |
---|
115 | BestClass = c; |
---|
116 | } |
---|
117 | } |
---|
118 | LeafErrors = Cases - LocalClassDist[BestClass]; |
---|
119 | ExtraLeafErrors = AddErrs(Cases, LeafErrors); |
---|
120 | |
---|
121 | if ( UpdateTree ) |
---|
122 | { |
---|
123 | T->Items = Cases; |
---|
124 | T->Leaf = BestClass; |
---|
125 | memcpy(T->ClassDist, LocalClassDist, (MaxClass + 1) * sizeof(ItemCount)); |
---|
126 | } |
---|
127 | |
---|
128 | if ( ! T->NodeType ) /* leaf */ |
---|
129 | { |
---|
130 | TreeErrors = LeafErrors + ExtraLeafErrors; |
---|
131 | |
---|
132 | if ( UpdateTree ) |
---|
133 | { |
---|
134 | T->Errors = TreeErrors; |
---|
135 | |
---|
136 | LocalVerbosity(1) |
---|
137 | { |
---|
138 | Intab(Sh); |
---|
139 | printf("%s (%.2f:%.2f/%.2f)\n", ClassName[T->Leaf], |
---|
140 | T->Items, LeafErrors, T->Errors); |
---|
141 | } |
---|
142 | } |
---|
143 | |
---|
144 | free(LocalClassDist); |
---|
145 | |
---|
146 | return TreeErrors; |
---|
147 | } |
---|
148 | |
---|
149 | /* Estimate errors for each branch */ |
---|
150 | |
---|
151 | Kp = Group(0, Fp, Lp, T) + 1; |
---|
152 | KnownCases = CountItems(Kp, Lp); |
---|
153 | |
---|
154 | PrevAllKnown = AllKnown; |
---|
155 | if ( Kp != Fp ) AllKnown = false; |
---|
156 | |
---|
157 | TreeErrors = MaxFactor = 0; |
---|
158 | |
---|
159 | ForEach(v, 1, T->Forks) |
---|
160 | { |
---|
161 | Ep = Group(v, Kp, Lp, T); |
---|
162 | |
---|
163 | if ( Kp <= Ep ) |
---|
164 | { |
---|
165 | Factor = CountItems(Kp, Ep) / KnownCases; |
---|
166 | |
---|
167 | if ( Factor >= MaxFactor ) |
---|
168 | { |
---|
169 | MaxBr = v; |
---|
170 | MaxFactor = Factor; |
---|
171 | } |
---|
172 | |
---|
173 | ForEach(i, Fp, Kp-1) |
---|
174 | { |
---|
175 | Weight[i] *= Factor; |
---|
176 | } |
---|
177 | |
---|
178 | TreeErrors += EstimateErrors(T->Branch[v], Fp, Ep, Sh+1, UpdateTree); |
---|
179 | |
---|
180 | Group(0, Fp, Ep, T); |
---|
181 | ForEach(i, Fp, Kp-1) |
---|
182 | { |
---|
183 | Weight[i] /= Factor; |
---|
184 | } |
---|
185 | } |
---|
186 | } |
---|
187 | |
---|
188 | AllKnown = PrevAllKnown; |
---|
189 | |
---|
190 | if ( ! UpdateTree ) |
---|
191 | { |
---|
192 | free(LocalClassDist); |
---|
193 | |
---|
194 | return TreeErrors; |
---|
195 | } |
---|
196 | |
---|
197 | /* See how the largest branch would fare */ |
---|
198 | |
---|
199 | BranchErrors = EstimateErrors(T->Branch[MaxBr], Fp, Lp, -1000, false); |
---|
200 | |
---|
201 | LocalVerbosity(1) |
---|
202 | { |
---|
203 | Intab(Sh); |
---|
204 | printf("%s: [%d%% N=%.2f tree=%.2f leaf=%.2f+%.2f br[%d]=%.2f]\n", |
---|
205 | AttName[T->Tested], |
---|
206 | (int) ((TreeErrors * 100) / (T->Items + 0.001)), |
---|
207 | T->Items, TreeErrors, LeafErrors, ExtraLeafErrors, |
---|
208 | MaxBr, BranchErrors); |
---|
209 | } |
---|
210 | |
---|
211 | /* See whether tree should be replaced with leaf or largest branch */ |
---|
212 | |
---|
213 | if ( LeafErrors + ExtraLeafErrors <= BranchErrors + 0.1 && |
---|
214 | LeafErrors + ExtraLeafErrors <= TreeErrors + 0.1 ) |
---|
215 | { |
---|
216 | LocalVerbosity(1) |
---|
217 | { |
---|
218 | Intab(Sh); |
---|
219 | printf("Replaced with leaf %s\n", ClassName[T->Leaf]); |
---|
220 | } |
---|
221 | |
---|
222 | T->NodeType = 0; |
---|
223 | T->Errors = LeafErrors + ExtraLeafErrors; |
---|
224 | Changed = true; |
---|
225 | } |
---|
226 | else |
---|
227 | if ( BranchErrors <= TreeErrors + 0.1 ) |
---|
228 | { |
---|
229 | LocalVerbosity(1) |
---|
230 | { |
---|
231 | Intab(Sh); |
---|
232 | printf("Replaced with branch %d\n", MaxBr); |
---|
233 | } |
---|
234 | |
---|
235 | AllKnown = PrevAllKnown; |
---|
236 | EstimateErrors(T->Branch[MaxBr], Fp, Lp, Sh, true); |
---|
237 | memcpy((char *) T, (char *) T->Branch[MaxBr], sizeof(TreeRec)); |
---|
238 | Changed = true; |
---|
239 | } |
---|
240 | else |
---|
241 | { |
---|
242 | T->Errors = TreeErrors; |
---|
243 | } |
---|
244 | |
---|
245 | AllKnown = PrevAllKnown; |
---|
246 | free(LocalClassDist); |
---|
247 | |
---|
248 | return T->Errors; |
---|
249 | } |
---|
250 | |
---|
251 | |
---|
252 | |
---|
253 | /*************************************************************************/ |
---|
254 | /* */ |
---|
255 | /* Remove unnecessary subset tests on missing values */ |
---|
256 | /* */ |
---|
257 | /*************************************************************************/ |
---|
258 | |
---|
259 | |
---|
260 | CheckPossibleValues(T) |
---|
261 | /* ------------------- */ |
---|
262 | Tree T; |
---|
263 | { |
---|
264 | Set HoldValues; |
---|
265 | int v, Bytes, b; |
---|
266 | Attribute A; |
---|
267 | char Any=0; |
---|
268 | |
---|
269 | if ( T->NodeType == BrSubset ) |
---|
270 | { |
---|
271 | A = T->Tested; |
---|
272 | |
---|
273 | Bytes = (MaxAttVal[A]>>3) + 1; |
---|
274 | HoldValues = (Set) malloc(Bytes); |
---|
275 | |
---|
276 | /* See if last (default) branch can be simplified or omitted */ |
---|
277 | |
---|
278 | ForEach(b, 0, Bytes-1) |
---|
279 | { |
---|
280 | T->Subset[T->Forks][b] &= PossibleValues[A][b]; |
---|
281 | Any |= T->Subset[T->Forks][b]; |
---|
282 | } |
---|
283 | |
---|
284 | if ( ! Any ) |
---|
285 | { |
---|
286 | T->Forks--; |
---|
287 | } |
---|
288 | |
---|
289 | /* Process each subtree, leaving only values in branch subset */ |
---|
290 | |
---|
291 | CopyBits(Bytes, PossibleValues[A], HoldValues); |
---|
292 | |
---|
293 | ForEach(v, 1, T->Forks) |
---|
294 | { |
---|
295 | CopyBits(Bytes, T->Subset[v], PossibleValues[A]); |
---|
296 | |
---|
297 | CheckPossibleValues(T->Branch[v]); |
---|
298 | } |
---|
299 | |
---|
300 | CopyBits(Bytes, HoldValues, PossibleValues[A]); |
---|
301 | |
---|
302 | free(HoldValues); |
---|
303 | } |
---|
304 | else |
---|
305 | if ( T->NodeType ) |
---|
306 | { |
---|
307 | ForEach(v, 1, T->Forks) |
---|
308 | { |
---|
309 | CheckPossibleValues(T->Branch[v]); |
---|
310 | } |
---|
311 | } |
---|
312 | } |
---|