1
2
3
4 """ Contains functionality for doing tree pruning
5
6 """
7 import numpy
8 from ML.DecTree import CrossValidate, DecTree
9 import copy
10
# Module-wide debug switch: when true, _Pruner prints a trace of every
# candidate replacement it evaluates.  The __main__ block sets it to 1.
_verbose = 0
12
def MaxCount(examples):
    """ given a set of examples, returns the most common result code

    The result code of an example is its last element.

    **Arguments**

      examples: a list of examples to be counted

    **Returns**

      the most common result code (as returned by numpy.argmax, so ties go
      to the smallest code)

    **Raises**

      ValueError if examples is empty

    **Notes**

      - only codes in the range 0..max(code) are tallied; negative codes are
        never counted.

    """
    resList = [x[-1] for x in examples]
    if not resList:
        raise ValueError('MaxCount requires at least one example')
    maxVal = max(resList)

    # Tally every code in a single pass instead of rescanning resList once
    # per possible code value.
    tally = {}
    for res in resList:
        tally[res] = tally.get(res, 0) + 1

    # Lay the tallies out by code value so that argmax yields the code itself.
    counts = [tally.get(code, 0) for code in range(maxVal + 1)]
    return numpy.argmax(counts)
32
def _GetLocalError(node):
    """ counts the examples stored on a node which that node misclassifies

    **Arguments**

      - node: the tree (node) to be evaluated.  Its own examples, as returned
        by node.GetExamples(), are used; the last element of each example is
        taken to be its true result code.

    **Returns**

      the number of examples for which node.ClassifyExample() disagrees with
      the example's result code

    """
    # appendExamples=0 is passed on every call, as in the original code;
    # presumably it stops the tree from re-storing the examples it is being
    # scored on - confirm against DecTreeNode.ClassifyExample.
    return sum(1 for example in node.GetExamples()
               if node.ClassifyExample(example, appendExamples=0) != example[-1])
41
def _Trace(level, *items):
    """ prints one line of _Pruner debugging output when _verbose is set

    The line is the items joined by single spaces, behind the
    indentation/level prefix used throughout _Pruner.  It is emitted through
    a single-argument print so the text is the same under Python 2 and 3.

    """
    if _verbose:
        print(' '.join([' ' * level, '<%d> ' % level] + [str(item) for item in items]))


def _Pruner(node, level=0):
    """Recursively finds and removes the nodes whose removals improve classification

    **Arguments**

      - node: the tree to be pruned.  The pruning data should already be contained
        within node (i.e. node.GetExamples() should return the pruning data)

      - level: (optional) the level of recursion, used only in _verbose printing


    **Returns**

      the pruned version of node


    **Notes**

      - This uses a greedy algorithm which basically does a DFS traversal of the tree,
        removing nodes whenever possible.

      - If removing a node does not affect the accuracy, it *will be* removed.  We
        favor smaller trees.

      - node itself is not modified: all replacements are made on deep copies.

      - NOTE(review): the def line was missing from the reviewed source; the
        default level=0 is reconstructed from the call _Pruner(tree) - confirm.

    """
    _Trace(level, '>>> Pruner')
    children = node.GetChildren()[:]

    bestTree = copy.deepcopy(node)
    # Any real error count beats infinity, so the first candidate is always
    # accepted.
    bestErr = float('inf')

    for i, child in enumerate(children):
        examples = child.GetExamples()
        if _verbose:
            _Trace(level, ' Child:', i, child.GetLabel())
            bestTree.Print()
            print('')

        if not len(examples):
            # Children that received no pruning examples are left intact.
            _Trace(level, ' No Examples', len(examples))
            continue
        _Trace(level, ' Examples', len(examples))

        if child.GetTerminal():
            # Leaves cannot be pruned any further.
            _Trace(level, ' Terminal')
            continue
        _Trace(level, ' Nonterminal')

        workTree = copy.deepcopy(bestTree)

        # Candidate 1: replace the child with its own recursively pruned
        # version.
        newNode = _Pruner(child, level=level + 1)
        workTree.ReplaceChildIndex(i, newNode)
        tempErr = _GetLocalError(workTree)
        if tempErr <= bestErr:
            bestErr = tempErr
            bestTree = copy.deepcopy(workTree)
            if _verbose:
                _Trace(level, '>->->->->->')
                _Trace(level, 'replacing:', i, child.GetLabel())
                child.Print()
                _Trace(level, 'with:')
                newNode.Print()
                _Trace(level, '<-<-<-<-<-<')
        else:
            # Not an improvement: put the original child back.
            workTree.ReplaceChildIndex(i, child)

        # Candidate 2: collapse the child into a single terminal node that
        # predicts the most common result code among the child's examples.
        bestGuess = MaxCount(child.GetExamples())
        newNode = DecTree.DecTreeNode(workTree, 'L:%d' % (bestGuess),
                                      label=bestGuess, isTerminal=1)
        newNode.SetExamples(child.GetExamples())
        workTree.ReplaceChildIndex(i, newNode)
        if _verbose:
            _Trace(level, 'ATTEMPT:')
            workTree.Print()
        newErr = _GetLocalError(workTree)
        _Trace(level, '---> ', newErr, bestErr)
        # "<=" (not "<"): an equally accurate but smaller tree is preferred.
        if newErr <= bestErr:
            bestErr = newErr
            bestTree = copy.deepcopy(workTree)
            if _verbose:
                _Trace(level, 'PRUNING:')
                workTree.Print()
        else:
            _Trace(level, 'FAIL')
            # Collapsing made things worse: put the original child back.
            workTree.ReplaceChildIndex(i, child)

    _Trace(level, '<<< out')
    return bestTree
146
def PruneTree(tree,trainExamples,testExamples,minimizeTestErrorOnly=1):
    """ implements a reduced-error pruning of decision trees

    This algorithm is described on page 69 of Mitchell's book.

    Pruning can be done using just the set of testExamples (the validation set)
    or both the testExamples and the trainExamples by setting minimizeTestErrorOnly
    to 0.

    **Arguments**

      - tree: the initial tree to be pruned

      - trainExamples: the examples used to train the tree

      - testExamples: the examples held out for testing the tree

      - minimizeTestErrorOnly: if this toggle is zero, all examples (i.e.
        _trainExamples_ + _testExamples_) will be used to evaluate the error.

    **Returns**

      a 2-tuple containing:

        1) the best tree

        2) the best error (the one which corresponds to that tree)

    **Notes**

      - the examples held by _tree_ are cleared and replaced by the pruning
        set as a side effect.

    """
    if minimizeTestErrorOnly:
        testSet = testExamples
    else:
        testSet = trainExamples + testExamples

    # Discard whatever examples the tree is holding, then run the pruning set
    # through it.  Only the side effect of this pass matters (its return value
    # is not needed): with appendExamples=1 the examples presumably get stored
    # on the nodes, which is where _Pruner expects to find its pruning data -
    # confirm against CrossValidate.CrossValidate.
    tree.ClearExamples()
    CrossValidate.CrossValidate(tree, testSet, appendExamples=1)

    newTree = _Pruner(tree)

    # Score the pruned tree on the same set and record its misclassified
    # examples on the tree.
    totErr, badEx = CrossValidate.CrossValidate(newTree, testSet)
    newTree.SetBadExamples(badEx)

    return newTree, totErr
203
204
205
206
207
def _testRandom():
    """ smoke test: grows a tree on random data, prunes it and prints both errors

    Writes the original and pruned trees to 'orig.pkl' and 'prune.pkl' in the
    current directory.

    NOTE(review): the def line of this function was missing from the reviewed
    source; the name _testRandom is a reconstruction - confirm.

    """
    from ML.DecTree import randomtest

    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nVars=10, randScale=0.5,
                                                                  nExamples=200)
    tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals)
    tree.Print()
    tree.Pickle('orig.pkl')
    # Single-argument print: same text under Python 2 and Python 3.
    print('original error is: %s' % (frac,))

    print('----Pruning')
    newTree, frac2 = PruneTree(tree, tree.GetTrainingExamples(), tree.GetTestExamples())
    newTree.Print()
    print('pruned error is: %s' % (frac2,))
    newTree.Pickle('prune.pkl')
222
223
def _testSpecific():
    """ smoke test: prunes an ID3 tree built from a small hand-written data set

    The hold-out set is the training set plus two extra copies of [0,1,1,0],
    which contradict the training point [0,1,1,1].

    NOTE(review): the def line of this function was missing from the reviewed
    source; the name _testSpecific is a reconstruction - confirm.

    """
    from ML.DecTree import ID3
    oPts = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]

    # list(...) keeps attrs a real list on Python 3 as well.
    tree = ID3.ID3Boot(oPts, attrs=list(range(3)), nPossibleVals=[2] * 4)
    tree.Print()
    err, badEx = CrossValidate.CrossValidate(tree, oPts)
    # Single-argument prints: same text under Python 2 and Python 3.
    print('original error: %s' % (err,))

    err, badEx = CrossValidate.CrossValidate(tree, tPts)
    print('original holdout error: %s' % (err,))
    newTree, frac2 = PruneTree(tree, oPts, tPts)
    newTree.Print()
    err, badEx = CrossValidate.CrossValidate(newTree, tPts)
    print('pruned holdout error is: %s' % (err,))
    print(badEx)

    print('%s %s' % (len(tree), len(newTree)))
250
def _testChain():
    """ smoke test: prunes an ID3 tree using the training data itself as hold-out set

    Builds a tree from a small hand-written data set with four binary
    descriptors, prints it and its error, prunes it and prints the result.
    This is the test run by the __main__ block.

    """
    from ML.DecTree import ID3
    oPts = [
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 1],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 1],
    ]
    # The hold-out set is deliberately the very same list as the training set.
    tPts = oPts

    nCols = len(oPts[0])
    # list(...) keeps attrs a real list on Python 3 as well.
    tree = ID3.ID3Boot(oPts, attrs=list(range(nCols - 1)), nPossibleVals=[2] * nCols)
    tree.Print()
    err, badEx = CrossValidate.CrossValidate(tree, oPts)
    # Single-argument prints: same text under Python 2 and Python 3.
    print('original error: %s' % (err,))

    err, badEx = CrossValidate.CrossValidate(tree, tPts)
    print('original holdout error: %s' % (err,))
    newTree, frac2 = PruneTree(tree, oPts, tPts)
    newTree.Print()
    err, badEx = CrossValidate.CrossValidate(newTree, tPts)
    print('pruned holdout error is: %s' % (err,))
    print(badEx)
283
284
if __name__ == '__main__':
    # Script entry point: switch on the pruning trace and run the chain test.
    _verbose = 1

    _testChain()
290