source: proiecte/Parallel-DT/R8/Src/prune.c @ 24

Last change on this file since 24 was 24, checked in by (none), 14 years ago

blabla

File size: 6.2 KB
Line 
1/*************************************************************************/
2/*                                                                       */
3/*      Prune a decision tree and predict its error rate                 */
4/*      ------------------------------------------------                 */
5/*                                                                       */
6/*************************************************************************/
7
8
9#include "defns.i"
10#include "types.i"
11#include "extern.i"
12
13
14extern  ItemCount       *Weight;
15
16Set     *PossibleValues=Nil;
17Boolean Changed;
18
19#define LocalVerbosity(x)       if (Sh >= 0 && VERBOSITY >= x)
20#define Intab(x)                Indent(x, "| ")
21
22
23
24/*************************************************************************/
25/*                                                                       */
26/*  Prune tree T, returning true if tree has been modified               */
27/*                                                                       */
28/*************************************************************************/
29
30
31Boolean Prune(T)
32/*      -----  */
33    Tree T;
34{
35    ItemNo i;
36    float EstimateErrors();
37    Attribute a;
38
39    InitialiseWeights();
40    AllKnown = true;
41
42    Verbosity(1) printf("\n");
43
44    Changed = false;
45
46    EstimateErrors(T, 0, MaxItem, 0, true);
47
48    if ( SUBSET )
49    {
50        if ( ! PossibleValues )
51        {
52            PossibleValues = (Set *) calloc(MaxAtt+1, sizeof(Set));
53        }
54
55        ForEach(a, 0, MaxAtt)
56        {
57            if ( MaxAttVal[a] )
58            {
59                PossibleValues[a] = (Set) malloc((MaxAttVal[a]>>3) + 1);
60                ClearBits((MaxAttVal[a]>>3) + 1, PossibleValues[a]);
61                ForEach(i, 1, MaxAttVal[a])
62                {
63                    SetBit(i, PossibleValues[a]);
64                }
65            }
66        }
67
68        CheckPossibleValues(T);
69    }
70
71    return Changed;
72}
73
74
75
76
77/*************************************************************************/
78/*                                                                       */
79/*      Estimate the errors in a given subtree                           */
80/*                                                                       */
81/*************************************************************************/
82
83
84float EstimateErrors(T, Fp, Lp, Sh, UpdateTree)
85/*    --------------  */
86    Tree T;
87    ItemNo Fp, Lp; 
88    short Sh;
89    Boolean UpdateTree;
90{ 
91    ItemNo i, Kp, Ep, Group();
92    ItemCount Cases, KnownCases, *LocalClassDist, TreeErrors, LeafErrors,
93        ExtraLeafErrors, BranchErrors, CountItems(), Factor, MaxFactor, AddErrs();
94    DiscrValue v, MaxBr;
95    ClassNo c, BestClass;
96    Boolean PrevAllKnown;
97
98    /*  Generate the class frequency distribution  */
99
100    Cases = CountItems(Fp, Lp);
101    LocalClassDist = (ItemCount *) calloc(MaxClass+1, sizeof(ItemCount));
102
103    ForEach(i, Fp, Lp)
104    { 
105        LocalClassDist[ Class(Item[i]) ] += Weight[i];
106    } 
107
108    /*  Find the most frequent class and update the tree  */
109
110    BestClass = T->Leaf;
111    ForEach(c, 0, MaxClass)
112    {
113        if ( LocalClassDist[c] > LocalClassDist[BestClass] )
114        {
115            BestClass = c;
116        }
117    }
118    LeafErrors = Cases - LocalClassDist[BestClass];
119    ExtraLeafErrors = AddErrs(Cases, LeafErrors);
120
121    if ( UpdateTree )
122    {
123        T->Items = Cases;
124        T->Leaf  = BestClass;
125        memcpy(T->ClassDist, LocalClassDist, (MaxClass + 1) * sizeof(ItemCount));
126    }
127
128    if ( ! T->NodeType )        /*  leaf  */
129    {
130        TreeErrors = LeafErrors + ExtraLeafErrors;
131
132        if ( UpdateTree )
133        {
134            T->Errors = TreeErrors;
135
136            LocalVerbosity(1)
137            {
138                Intab(Sh);
139                printf("%s (%.2f:%.2f/%.2f)\n", ClassName[T->Leaf],
140                        T->Items, LeafErrors, T->Errors);
141            }
142        }
143
144        free(LocalClassDist);
145
146        return TreeErrors;
147    }
148
149    /*  Estimate errors for each branch  */
150
151    Kp = Group(0, Fp, Lp, T) + 1;
152    KnownCases = CountItems(Kp, Lp);
153
154    PrevAllKnown = AllKnown;
155    if ( Kp != Fp ) AllKnown = false;
156
157    TreeErrors = MaxFactor = 0;
158
159    ForEach(v, 1, T->Forks)
160    {
161        Ep = Group(v, Kp, Lp, T);
162
163        if ( Kp <= Ep )
164        {
165            Factor = CountItems(Kp, Ep) / KnownCases;
166
167            if ( Factor >= MaxFactor )
168            {
169                MaxBr = v;
170                MaxFactor = Factor;
171            }
172
173            ForEach(i, Fp, Kp-1)
174            {
175                Weight[i] *= Factor;
176            }
177
178            TreeErrors += EstimateErrors(T->Branch[v], Fp, Ep, Sh+1, UpdateTree);
179
180            Group(0, Fp, Ep, T);
181            ForEach(i, Fp, Kp-1)
182            {
183                Weight[i] /= Factor;
184            }
185        }
186    }
187 
188    AllKnown = PrevAllKnown;
189
190    if ( ! UpdateTree )
191    {
192        free(LocalClassDist);
193
194        return TreeErrors;
195    }
196
197    /*  See how the largest branch would fare  */
198
199    BranchErrors = EstimateErrors(T->Branch[MaxBr], Fp, Lp, -1000, false);
200
201    LocalVerbosity(1)
202    {
203        Intab(Sh);
204        printf("%s:  [%d%%  N=%.2f  tree=%.2f  leaf=%.2f+%.2f  br[%d]=%.2f]\n",
205                AttName[T->Tested],
206                (int) ((TreeErrors * 100) / (T->Items + 0.001)),
207                T->Items, TreeErrors, LeafErrors, ExtraLeafErrors,
208                MaxBr, BranchErrors);
209    }
210
211    /*  See whether tree should be replaced with leaf or largest branch  */
212
213    if ( LeafErrors + ExtraLeafErrors <= BranchErrors + 0.1 &&
214         LeafErrors + ExtraLeafErrors <= TreeErrors + 0.1 )
215    {
216        LocalVerbosity(1)
217        {
218            Intab(Sh);
219            printf("Replaced with leaf %s\n", ClassName[T->Leaf]);
220        }
221
222        T->NodeType = 0;
223        T->Errors = LeafErrors + ExtraLeafErrors;
224        Changed = true;
225    }
226    else
227    if ( BranchErrors <= TreeErrors + 0.1 )
228    {
229        LocalVerbosity(1)
230        {
231            Intab(Sh);
232            printf("Replaced with branch %d\n", MaxBr);
233        }
234
235        AllKnown = PrevAllKnown;
236        EstimateErrors(T->Branch[MaxBr], Fp, Lp, Sh, true);
237        memcpy((char *) T, (char *) T->Branch[MaxBr], sizeof(TreeRec));
238        Changed = true;
239    }
240    else
241    {
242        T->Errors = TreeErrors;
243    }
244
245    AllKnown = PrevAllKnown;
246    free(LocalClassDist);
247
248    return T->Errors;
249}
250
251
252
253/*************************************************************************/
254/*                                                                       */
255/*      Remove unnecessary subset tests on missing values                */
256/*                                                                       */
257/*************************************************************************/
258
259
260    CheckPossibleValues(T)
261/*  -------------------  */
262    Tree T;
263{
264    Set HoldValues;
265    int v, Bytes, b;
266    Attribute A;
267    char Any=0;
268
269    if ( T->NodeType == BrSubset )
270    {
271        A = T->Tested;
272
273        Bytes = (MaxAttVal[A]>>3) + 1;
274        HoldValues = (Set) malloc(Bytes);
275
276        /*  See if last (default) branch can be simplified or omitted  */
277
278        ForEach(b, 0, Bytes-1)
279        {
280            T->Subset[T->Forks][b] &= PossibleValues[A][b];
281            Any |= T->Subset[T->Forks][b];
282        }
283
284        if ( ! Any )
285        {
286            T->Forks--;
287        }
288
289        /*  Process each subtree, leaving only values in branch subset  */
290
291        CopyBits(Bytes, PossibleValues[A], HoldValues);
292
293        ForEach(v, 1, T->Forks)
294        {
295            CopyBits(Bytes, T->Subset[v], PossibleValues[A]);
296
297            CheckPossibleValues(T->Branch[v]);
298        }
299
300        CopyBits(Bytes, HoldValues, PossibleValues[A]);
301
302        free(HoldValues);
303    }
304    else
305    if ( T->NodeType )
306    {
307        ForEach(v, 1, T->Forks)
308        {
309            CheckPossibleValues(T->Branch[v]);
310        }
311    }
312}
Note: See TracBrowser for help on using the repository browser.