1
2
3
4
5
6
7
8 """
9
10 """
11
12 import numpy
13 import random
14 from ML.DecTree import QuantTree, ID3
15 from ML.InfoTheory import entropy
16 from ML.Data import Quantize
17
def FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes,
             nPossibleVals, attrs, **kwargs):
    """ Finds the variable giving the largest gain at a tree node

    **Arguments**

      - resCodes: the integer result code of each example (parallel to
        _examples_)

      - examples: the examples; each one is indexed by variable number

      - nBoundsPerVar: for each variable, the number of quantization bounds
        to use.  A value > 0 means the variable is quantized with
        _Quantize.FindVarMultQuantBounds_, 0 means it is treated as a
        discrete variable (information gain computed from
        _ID3.GenVarTable_), and a negative value means the variable is
        skipped.

      - nPossibleRes: the number of possible result codes

      - nPossibleVals: the number of possible values of each variable
        (used for discrete variables)

      - attrs: the indices of the variables to consider

      - kwargs: if _randomDescriptors_ is > 0 and smaller than
        len(attrs), only that many randomly chosen entries of _attrs_
        are considered (_attrs_ itself is not modified)

    **Returns**

      a 3-tuple: (index of the best variable, its gain, its quantization
      bounds).  If no variable is selected the result is (-1, -1e6, []).
      On ties the variable that comes first in _attrs_ wins.
    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    if len(examples) == 0:
        return best, bestGain, bestBounds

    nToTake = kwargs.get('randomDescriptors', 0)
    if 0 < nToTake < len(attrs):
        # list() so the shuffle works even where range() is lazy
        ids = list(range(len(attrs)))
        random.shuffle(ids)
        attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            try:
                vTable = [x[var] for x in examples]
            except IndexError:
                print('index error retrieving variable: %d' % var)
                raise
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(
                vTable, nBounds, resCodes, nPossibleRes)
        elif nBounds == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            # negative bound count: this variable is never selected
            gainHere = -1e6
            qBounds = []
        # strict '>' so the earliest variable wins a tie
        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var

    if best == -1:
        # diagnostic dump: nothing beat the initial gain
        print('best unaltered')
        print('\tattrs: %s' % (attrs,))
        print('\tnBounds: %s' % (numpy.take(nBoundsPerVar, attrs),))
        print('\texamples:')
        for example in examples:
            print('\t\t %s' % (example,))

    return best, bestGain, bestBounds
69
70
def BuildQuantTree(examples, target, attrs, nPossibleVals, nBoundsPerVar,
                   depth=0, maxDepth=-1, **kwargs):
    """ Recursively builds a quantized decision tree

    **Arguments**

      - examples: a list of lists (nInstances x nVariables+1) of variable
        values + instance values; the last entry of each example is its
        integer result code

      - target: an int (not used by this function; kept for interface
        compatibility)

      - attrs: a list of ints indicating which variables can be used in
        the tree

      - nPossibleVals: a list containing the number of possible values of
        every variable; the last entry is the number of possible result
        codes

      - nBoundsPerVar: the number of bounds to include for each variable
        (interpreted as in _FindBest_)

      - depth: (optional) the current depth in the tree

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown; a negative value means no limit

      - kwargs: if _recycleVars_ is true, a variable that has been split
        on stays available further down the tree; all keyword arguments
        are also passed on to _FindBest_ and to the recursive calls

    **Returns**

      a QuantTree.QuantTreeNode with the decision tree

    **NOTE:** This code cannot bootstrap (start from nothing...)
      use _QuantTreeBoot_ (below) for that.
    """
    tree = QuantTree.QuantTreeNode(None, 'node')
    tree.SetData(-666)
    nPossibleRes = nPossibleVals[-1]

    # tally how many examples carry each result code
    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1
    nzCounts = numpy.nonzero(counts)[0]

    if len(nzCounts) == 1:
        # only one result code left: a pure leaf
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    if len(attrs) == 0 or (maxDepth >= 0 and depth > maxDepth):
        # out of variables or too deep: predict the most common result
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    best, _, bestBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                   nPossibleRes, nPossibleVals, attrs,
                                   **kwargs)
    if best == -1:
        # No usable variable (e.g. every remaining one has a negative
        # bound count).  Splitting on -1 would index the result column,
        # so bottom out with the most common result instead.
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    tree.SetName('Var: %d' % (best))
    tree.SetLabel(best)
    tree.SetQuantBounds(bestBounds)
    tree.SetTerminal(0)

    # One subset of examples per child, in child order.
    if len(bestBounds) > 0:
        # child i takes the not-yet-assigned examples whose value is below
        # bound i; whatever remains after the last bound goes to a final child
        subsets = []
        remaining = examples
        for bound in bestBounds:
            subsets.append([ex for ex in remaining if ex[best] < bound])
            remaining = [ex for ex in remaining if not (ex[best] < bound)]
        subsets.append(remaining)
    else:
        # discrete variable: one child per possible value
        subsets = [[ex for ex in examples if ex[best] == val]
                   for val in range(nPossibleVals[best])]

    for subset in subsets:
        if len(subset) == 0:
            # no examples reach this branch (this can and does happen),
            # so label it with the most common result at this node
            v = numpy.argmax(counts)
            tree.AddChild('%d' % v, label=v, data=0.0, isTerminal=1)
        else:
            tree.AddChildNode(BuildQuantTree(subset, best, nextAttrs,
                                             nPossibleVals, nBoundsPerVar,
                                             depth=depth + 1,
                                             maxDepth=maxDepth, **kwargs))
    return tree
198
def QuantTreeBoot(examples, attrs, nPossibleVals, nBoundsPerVar,
                  initialVar=None, maxDepth=-1, **kwargs):
    """ Bootstrapping code for the QuantTree

    If _initialVar_ is not set, the algorithm will automatically
    choose the first variable in the tree (the standard greedy
    approach).  Otherwise, _initialVar_ will be used as the first
    split.

    **Arguments**

      - examples, nPossibleVals, nBoundsPerVar, maxDepth, kwargs: as for
        _BuildQuantTree_

      - attrs: the variables that may be used; variables whose entry in
        _nBoundsPerVar_ is -1 are dropped (the caller's list is not
        modified)

      - initialVar: (optional) the variable to split on at the root

    **Returns**

      the root QuantTree.QuantTreeNode.  If no variable can be split on
      (e.g. every attribute was dropped, or there are no examples), the
      root is a leaf predicting the most common result.
    """
    attrs = list(attrs)
    for i, nBounds in enumerate(nBoundsPerVar):
        if nBounds == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None, 'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    # tally how many examples carry each result code
    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best, gainHere, qBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                           nPossibleRes, nPossibleVals, attrs,
                                           **kwargs)
    else:
        best = initialVar
        if nBoundsPerVar[best] > 0:
            vTable = [x[best] for x in examples]
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(
                vTable, nBoundsPerVar[best], resCodes, nPossibleRes)
        elif nBoundsPerVar[best] == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetData(gainHere)
    if best == -1:
        # Nothing to split on.  Splitting on -1 would index the result
        # column, so make the root a leaf with the most common result.
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    tree.SetName('Var: %d' % (best))
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)

    nextAttrs = list(attrs)
    # an explicit initialVar need not be among the (filtered) attrs
    if not kwargs.get('recycleVars', 0) and best in nextAttrs:
        nextAttrs.remove(best)

    # One subset of examples per child, in child order.
    if len(qBounds) > 0:
        # child i takes the not-yet-assigned examples whose value is below
        # bound i; whatever remains after the last bound goes to a final child
        subsets = []
        remaining = examples
        for bound in qBounds:
            subsets.append([ex for ex in remaining if ex[best] < bound])
            remaining = [ex for ex in remaining if not (ex[best] < bound)]
        subsets.append(remaining)
    else:
        # discrete variable: one child per possible value
        subsets = [[ex for ex in examples if ex[best] == val]
                   for val in range(nPossibleVals[best])]

    for subset in subsets:
        if len(subset) != 0:
            tree.AddChildNode(BuildQuantTree(subset, best, nextAttrs,
                                             nPossibleVals, nBoundsPerVar,
                                             depth=1, maxDepth=maxDepth,
                                             **kwargs))
        else:
            # empty branch: label it with the most common overall result
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % (v), label=v, data=0.0, isTerminal=1)
    return tree
297
298
def TestTree():
    """ Smoke test: builds an ID3 tree (maxDepth=1) on a small discrete
    data set and prints it.

    """
    examples1 = [['p1', 0, 1, 0, 0],
                 ['p2', 0, 0, 0, 1],
                 ['p3', 0, 0, 1, 2],
                 ['p4', 0, 1, 1, 2],
                 ['p5', 1, 0, 0, 2],
                 ['p6', 1, 0, 1, 2],
                 ['p7', 1, 1, 0, 2],
                 ['p8', 1, 1, 1, 0]
                 ]
    # skip the name column (0) and the result column (last)
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 2, 2, 2, 3]
    t1 = ID3.ID3Boot(examples1, attrs, nPossibleVals, maxDepth=1)
    t1.Print()
316
317
def TestQuantTree():
    """ Smoke test: builds QuantTrees, with and without a depth limit, on a
    data set where variable 3 is quantized with one bound; pickles and
    prints each tree.

    """
    examples1 = [['p1', 0, 1, 0.1, 0],
                 ['p2', 0, 0, 0.1, 1],
                 ['p3', 0, 0, 1.1, 2],
                 ['p4', 0, 1, 1.1, 2],
                 ['p5', 1, 0, 0.1, 2],
                 ['p6', 1, 0, 1.1, 2],
                 ['p7', 1, 1, 0.1, 2],
                 ['p8', 1, 1, 1.1, 0]
                 ]
    # skip the name column (0) and the result column (last)
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 2, 2, 0, 3]
    boundsPerVar = [0, 0, 0, 1, 0]

    print('base')
    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Pickle('regress/QuantTree1.pkl')
    t1.Print()

    print('depth limit')
    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar,
                       maxDepth=1)
    t1.Pickle('regress/QuantTree1.pkl')
    t1.Print()
344
def TestQuantTree2():
    """ Smoke test: builds a QuantTree with variables 1 and 3 each quantized
    with one bound, prints and pickles it, then prints the classification
    of every training example.

    """
    # NOTE(review): this function's original header line is missing from
    # the source; the name TestQuantTree2 is a reconstruction.
    examples1 = [['p1', 0.1, 1, 0.1, 0],
                 ['p2', 0.1, 0, 0.1, 1],
                 ['p3', 0.1, 0, 1.1, 2],
                 ['p4', 0.1, 1, 1.1, 2],
                 ['p5', 1.1, 0, 0.1, 2],
                 ['p6', 1.1, 0, 1.1, 2],
                 ['p7', 1.1, 1, 0.1, 2],
                 ['p8', 1.1, 1, 1.1, 0]
                 ]
    # skip the name column (0) and the result column (last)
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 0, 2, 0, 3]
    boundsPerVar = [0, 1, 0, 1, 0]

    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Print()
    t1.Pickle('regress/QuantTree2.pkl')

    for example in examples1:
        print('%s %s' % (example, t1.ClassifyExample(example)))
368
if __name__ == "__main__":
    # Run the module's smoke tests when executed as a script.
    for _smokeTest in (TestTree, TestQuantTree):
        _smokeTest()
372
373