1
2
3
4 """ code for dealing with forests (collections) of decision trees
5
6 **NOTE** This code should be obsolete now that ML.Composite.Composite is up and running.
7
8 """
9 import cPickle
10 import numpy
11 from ML.DecTree import CrossValidate,PruneTree
12
class Forest(object):
    """A forest of unique decision trees.

    Adding a tree that is already present (as judged by ``==``) does not
    create a new entry; instead that tree's count is incremented and its
    error is added to the running error sum.

    Typical usage:

      1) grow the forest with AddTree (or Grow) until happy with it

      2) call AverageErrors to turn the summed errors into average errors

      3) call SortTrees to put things in order by either error or count

    The three parallel lists ``treeList``, ``countList`` and ``errList`` are
    kept index-aligned: entry ``i`` of each describes the same tree.
    """

    def MakeHistogram(self, eps=0.001):
        """Creates a histogram of (error, count) pairs.

        Consecutive trees are merged into the current bin as long as their
        error exceeds the bin's first error by no more than ``eps``; the
        counts of merged trees are summed.  The bins are only meaningful
        when ``errList`` is sorted in ascending order (call SortTrees with
        sortOnError set first).

        **Arguments**

          - eps: (optional) tolerance used when merging trees into a bin,
            defaults to 0.001

        **Returns**

          a list of (error, totalCount) tuples, one per bin; an empty list
          if the forest contains no trees
        """
        if not self.treeList:
            return []
        histo = []
        lastErr = self.errList[0]
        countHere = self.countList[0]
        for err, count in zip(self.errList[1:], self.countList[1:]):
            if err - lastErr > eps:
                # this tree starts a new bin: close out the current one
                histo.append((lastErr, countHere))
                lastErr = err
                countHere = count
            else:
                countHere += count
        # bins are only emitted when the next one starts, so the final bin
        # must be flushed explicitly
        histo.append((lastErr, countHere))
        return histo

    def CollectVotes(self, example):
        """Collects votes across every member of the forest for an example.

        **Arguments**

          - example: the example to classify; it is passed unchanged to each
            tree's ClassifyExample method

        **Returns**

          a list with one entry per tree (in ``treeList`` order) holding that
          tree's classification of ``example``
        """
        return [tree.ClassifyExample(example) for tree in self.treeList]

    def ClassifyExample(self, example):
        """Classifies the given example using the entire forest.

        This is treated as a voting problem: each tree votes for a result,
        each vote is weighted by the number of times that tree was added to
        the forest, and the result with the most weighted votes wins (ties go
        to the smallest result value).  The individual tree votes are stored
        and can be retrieved with GetVoteDetails.

        **FIX:** there is no proper statistical confidence interval here; the
        confidence measure is simply the (count-weighted) fraction of votes
        cast for the winning result.

        **Arguments**

          - example: the example to classify

        **Returns**

          a 2-tuple consisting of:

            1) the winning result, as an int

            2) the fraction of weighted votes cast for it

        **Raises**

          ValueError if the forest contains no trees
        """
        self.treeVotes = self.CollectVotes(example)
        if not self.treeVotes:
            raise ValueError('cannot classify an example with an empty forest')
        # Tree results are used as indices into the tally.  Sizing the tally
        # from the observed votes means it does not depend on Grow having
        # been called (which is what sets _nPossible).  It also cannot
        # overflow when a result value exceeds the number of variables.
        votes = [0] * (max(self.treeVotes) + 1)
        for res, count in zip(self.treeVotes, self.countList):
            votes[res] += count

        totVotes = sum(votes)
        res = int(numpy.argmax(votes))
        return res, float(votes[res]) / float(totVotes)

    def GetVoteDetails(self):
        """Returns the details of the last vote the forest conducted.

        This is the list of per-tree votes stored by the most recent call to
        ClassifyExample; it is an empty list if no voting has been done yet.
        """
        return self.treeVotes

    def Grow(self, examples, attrs, nPossibleVals, nTries=10, pruneIt=0,
             lessGreedy=0):
        """Grows the forest by adding trees.

        Each new tree is built with CrossValidate.CrossValidationDriver and
        then added with AddTree.  Progress is printed roughly every tenth of
        the run (every iteration when nTries < 10).

        **Arguments**

          - examples: the examples to be used for training

          - attrs: a list of the attributes to be used in training

          - nPossibleVals: a list with the number of possible values each
            variable (as well as the result) can take on

          - nTries: the number of new trees to add

          - pruneIt: a toggle for whether or not the tree should be pruned

          - lessGreedy: toggles the use of a less greedy construction
            algorithm where each possible tree root is used. The best tree
            from each step is actually added to the forest.
        """
        self._nPossible = nPossibleVals
        # nTries/10 is 0 for nTries < 10, which used to make the progress
        # check below raise ZeroDivisionError
        reportEvery = max(nTries // 10, 1)
        for i in range(nTries):
            tree, frac = CrossValidate.CrossValidationDriver(
                examples, attrs, nPossibleVals, silent=1, calcTotalError=1,
                lessGreedy=lessGreedy)
            if pruneIt:
                tree, frac2 = PruneTree.PruneTree(tree, tree.GetTrainingExamples(),
                                                  tree.GetTestExamples(),
                                                  minimizeTestErrorOnly=0)
                # same text the old `print 'prune: ', frac, frac2` emitted
                print('prune:  %s %s' % (frac, frac2))
                frac = frac2
            self.AddTree(tree, frac)
            if i % reportEvery == 0:
                print('Cycle: % 4d' % (i))

    def Pickle(self, fileName='foo.pkl'):
        """Writes this forest to a file so that it can be easily loaded later.

        **Arguments**

          - fileName: the name of the file to be written
        """
        # the context manager closes the file even if pickling fails
        with open(fileName, 'wb+') as pFile:
            cPickle.dump(self, pFile, 1)

    def AddTree(self, tree, error):
        """Adds a tree to the forest.

        If an identical tree (as judged by ``==``) is already present, its
        count is incremented and ``error`` is added to its error sum instead.

        **Arguments**

          - tree: the new tree

          - error: its error value

        **NOTE:** errList is run as an accumulator, so you probably want to
        call AverageErrors after finishing the forest.
        """
        if tree in self.treeList:
            idx = self.treeList.index(tree)
            self.errList[idx] += error
            self.countList[idx] += 1
        else:
            self.treeList.append(tree)
            self.errList.append(error)
            self.countList.append(1)

    def AverageErrors(self):
        """Converts the summed errors into average errors, in place.

        Each entry of errList is divided by the matching entry of countList.
        Calling this more than once divides again.
        """
        # float() forces true division, also under Python 2 with int errors
        self.errList = [float(err) / count
                        for err, count in zip(self.errList, self.countList)]

    def SortTrees(self, sortOnError=1):
        """Sorts the trees (and their counts and errors) in ascending order.

        **Arguments**

          - sortOnError: toggles sorting on the trees' errors rather than on
            their counts
        """
        if sortOnError:
            order = numpy.argsort(self.errList)
        else:
            order = numpy.argsort(self.countList)

        # apply the same permutation to all three parallel lists so they stay
        # index-aligned; results are kept as plain Python lists
        self.treeList = [self.treeList[x] for x in order]
        self.countList = [self.countList[x] for x in order]
        self.errList = [self.errList[x] for x in order]

    def GetTree(self, i):
        """Returns tree number ``i``."""
        return self.treeList[i]

    def SetTree(self, i, val):
        """Replaces tree number ``i`` with ``val``."""
        self.treeList[i] = val

    def GetCount(self, i):
        """Returns the number of times tree ``i`` was added."""
        return self.countList[i]

    def SetCount(self, i, val):
        """Sets the count of tree ``i`` to ``val``."""
        self.countList[i] = val

    def GetError(self, i):
        """Returns the (summed or averaged) error of tree ``i``."""
        return self.errList[i]

    def SetError(self, i, val):
        """Sets the error of tree ``i`` to ``val``."""
        self.errList[i] = val

    def GetDataTuple(self, i):
        """Returns all relevant data about a particular tree in the forest.

        **Arguments**

          - i: an integer indicating which tree should be returned

        **Returns**

          a 3-tuple consisting of:

            1) the tree

            2) its count

            3) its error
        """
        return (self.treeList[i], self.countList[i], self.errList[i])

    def SetDataTuple(self, i, tup):
        """Sets all relevant data for a particular tree in the forest.

        **Arguments**

          - i: an integer indicating which tree should be set

          - tup: a 3-tuple consisting of:

            1) the tree

            2) its count

            3) its error
        """
        self.treeList[i], self.countList[i], self.errList[i] = tup

    def GetAllData(self):
        """Returns everything we know.

        **Returns**

          a 3-tuple consisting of:

            1) our list of trees

            2) our list of tree counts

            3) our list of tree errors

          The lists themselves (not copies) are returned, so modifying them
          modifies the forest.
        """
        return (self.treeList, self.countList, self.errList)

    def __len__(self):
        """Allows len(forest) to work: the number of distinct trees."""
        return len(self.treeList)

    def __getitem__(self, which):
        """Allows forest[i] to work; returns the GetDataTuple result."""
        return self.GetDataTuple(which)

    def __str__(self):
        """Allows the forest to show itself as a string."""
        pieces = ['Forest\n']
        for i, (count, err) in enumerate(zip(self.countList, self.errList)):
            pieces.append(
                ' Tree % 4d: % 5d occurances %%% 5.2f average error\n' % (i, count,
                                                                          100. * err))
        return ''.join(pieces)

    def __init__(self):
        # parallel, index-aligned lists describing the distinct trees
        self.treeList = []
        self.errList = []
        self.countList = []
        # per-tree votes from the most recent ClassifyExample call
        self.treeVotes = []
288
if __name__ == '__main__':
    # Smoke test: adding the same tree twice should leave a single entry
    # whose count is 2 and whose averaged error is 0.5.
    from ML.DecTree import DecTree
    forest = Forest()
    node = DecTree.DecTreeNode(None, 'foo')
    for _ in range(2):
        forest.AddTree(node, 0.5)
    forest.AverageErrors()
    forest.SortTrees()
    print(forest)
298