Package ML :: Package DecTree :: Module BuildQuantTree
[hide private]
[frames] | no frames]

Source Code for Module ML.DecTree.BuildQuantTree

  1  ## Automatically adapted for numpy.oldnumeric Jun 27, 2008 by -c 
  2   
  3  # $Id: BuildQuantTree.py 742 2008-07-05 07:42:38Z glandrum $ 
  4  # 
  5  #  Copyright (C) 2001-2008  greg Landrum and Rational Discovery LLC 
  6  #  All Rights Reserved 
  7  # 
  8  """  
  9   
 10  """ 
 11   
 12  import numpy 
 13  import random 
 14  from ML.DecTree import QuantTree, ID3 
 15  from ML.InfoTheory import entropy 
 16  from ML.Data import Quantize 
 17   
def FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes,
             nPossibleVals, attrs, **kwargs):
    """ Finds the variable in _attrs_ that gives the largest information gain

      **Arguments**

        - resCodes: the result code of each example (same order as
          _examples_)

        - examples: a sequence of examples, each indexable by variable id

        - nBoundsPerVar: indexable by variable id.  For each variable:

            - a value > 0: the variable is quantized with that many bounds
              using Quantize.FindVarMultQuantBounds

            - a value of 0: the variable is used as-is; its gain is
              computed with ID3.GenVarTable and entropy.InfoGain

            - a value < 0: the variable is never selected

        - nPossibleRes: the number of possible result codes

        - nPossibleVals: the number of possible values of every variable

        - attrs: the ids of the variables that may be considered

        - randomDescriptors: (optional keyword) if > 0 and smaller than
          len(attrs), only that many randomly chosen members of _attrs_
          are considered

      **Returns**

        a 3-tuple:

          1) the id of the best variable (-1 if none was found)

          2) the gain of that variable (-1e6 if none was found)

          3) the quantization bounds of that variable (empty for
             unquantized variables or if none was found)

      **Notes**

        - when there are examples but no variable could be selected, a
          diagnostic dump is printed before returning

    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    # len() rather than truthiness: examples is not guaranteed to be a list
    if not len(examples):
        return best, bestGain, bestBounds

    nToTake = kwargs.get('randomDescriptors', 0)
    if nToTake > 0:
        nAttrs = len(attrs)
        if nToTake < nAttrs:
            # shuffle needs a real (mutable) list
            ids = list(range(nAttrs))
            random.shuffle(ids)
            attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            try:
                vTable = [x[var] for x in examples]
            except IndexError:
                print('index error retrieving variable: %d' % var)
                raise
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBounds,
                                                                resCodes,
                                                                nPossibleRes)
        elif nBounds == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            # negative bound count: this variable can never win
            gainHere = -1e6
            qBounds = []
        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var

    if best == -1:
        # Diagnostics only.  A plain list comprehension is used instead of
        # an (unimported) take() so that this path cannot itself fail, even
        # when attrs is empty.
        print('best unaltered')
        print('\tattrs: %s' % (attrs,))
        print('\tnBounds: %s' % ([nBoundsPerVar[x] for x in attrs],))
        print('\texamples:')
        for example in examples:
            print('\t\t%s' % (example,))

    return best, bestGain, bestBounds
69 70
def BuildQuantTree(examples, target, attrs, nPossibleVals, nBoundsPerVar,
                   depth=0, maxDepth=-1, **kwargs):
    """ Recursively builds a (sub)tree from a set of examples

      **Arguments**

        - examples: a list of lists (nInstances x nVariables+1) of variable
          values + instance values

        - target: an int (not used in the body; recursive calls pass the
          variable selected at the parent node)

        - attrs: a list of ints indicating which variables can be used in
          the tree

        - nPossibleVals: a list containing the number of possible values of
          every variable.

        - nBoundsPerVar: the number of bounds to include for each variable

        - depth: (optional) the current depth in the tree

        - maxDepth: (optional) the maximum depth to which the tree
          will be grown

        - recycleVars: (optional keyword) if true, the variable used at
          this node remains available to the subtrees

        - any other keyword arguments are passed on to FindBest and to the
          recursive calls

      **Returns**

        a QuantTree.QuantTreeNode with the decision tree

      **NOTE:** This code cannot bootstrap (start from nothing...)
            use _QuantTreeBoot_ (below) for that.
    """
    tree = QuantTree.QuantTreeNode(None, 'node')
    tree.SetData(-666)
    nPossibleRes = nPossibleVals[-1]

    # counts of each result code:
    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1
    nzCounts = numpy.nonzero(counts)[0]

    if len(nzCounts) == 1:
        # bottomed out because there is only one result code left
        # with any counts (i.e. there's only one type of example
        # left... this is GOOD!).
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    if len(attrs) == 0 or (maxDepth >= 0 and depth > maxDepth):
        # Bottomed out: no variables left or max depth hit
        # We don't really know what to do here, so
        # use the heuristic of picking the most prevalent
        # result
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    # find the variable which gives us the largest information gain
    best, bestGain, bestBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                          nPossibleRes, nPossibleVals, attrs,
                                          **kwargs)

    # remove that variable from the lists of possible variables
    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    # set some info at this node
    tree.SetName('Var: %d' % (best))
    tree.SetLabel(best)
    tree.SetQuantBounds(bestBounds)
    tree.SetTerminal(0)

    def _addBranch(subset):
        # Adds one child for a subset of the examples: a terminal node
        # carrying the most prevalent result at *this* node when the subset
        # is empty (this can, and does, happen), a recursively built subtree
        # otherwise.
        if len(subset) == 0:
            v = numpy.argmax(counts)
            tree.AddChild('%d' % v, label=v, data=0.0, isTerminal=1)
        else:
            tree.AddChildNode(BuildQuantTree(subset, best,
                                             nextAttrs, nPossibleVals,
                                             nBoundsPerVar,
                                             depth=depth + 1, maxDepth=maxDepth,
                                             **kwargs))

    if len(bestBounds) > 0:
        # Quantized variable: peel off, bound by bound, the examples that
        # fall below each bound.  A single order-preserving pass per bound
        # replaces repeated list.remove() calls (quadratic in the number of
        # examples).
        remaining = examples
        for bound in bestBounds:
            below = []
            rest = []
            for ex in remaining:
                if ex[best] < bound:
                    below.append(ex)
                else:
                    rest.append(ex)
            remaining = rest
            _addBranch(below)
        # add the last points remaining
        _addBranch(remaining)
    else:
        # Unquantized variable: one child per possible value
        for val in range(nPossibleVals[best]):
            _addBranch([ex for ex in examples if ex[best] == val])
    return tree
198
def QuantTreeBoot(examples, attrs, nPossibleVals, nBoundsPerVar, initialVar=None,
                  maxDepth=-1, **kwargs):
    """ Bootstrapping code for the QuantTree

      If _initialVar_ is not set, the algorithm will automatically
      choose the first variable in the tree (the standard greedy
      approach).  Otherwise, _initialVar_ will be used as the first
      split.

      **Arguments**

        - examples: a list of lists of variable values + instance values
          (the result code is the last element of each example)

        - attrs: the ids of the variables that can be used in the tree
          (not modified; variables whose entry in _nBoundsPerVar_ is -1
          are dropped from a private copy)

        - nPossibleVals: a list containing the number of possible values of
          every variable; the last entry is the number of result codes

        - nBoundsPerVar: the number of bounds to include for each variable

        - initialVar: (optional) the variable to use for the first split

        - maxDepth: (optional) passed on to BuildQuantTree

        - recycleVars: (optional keyword) if true, the variable used at
          the root remains available to the subtrees

        - any other keyword arguments are passed on to FindBest and
          BuildQuantTree

      **Returns**

        a QuantTree.QuantTreeNode with the decision tree

      **Notes**

        - unless _recycleVars_ is set, the root variable must be present in
          the filtered attribute list; list.remove() raises a ValueError
          otherwise (e.g. when no usable variable was found)

    """
    # work on a copy so that the caller's list is untouched
    attrs = list(attrs)
    for i, nBounds in enumerate(nBoundsPerVar):
        if nBounds == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None, 'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    # counts of each result code:
    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best, gainHere, qBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                           nPossibleRes, nPossibleVals, attrs,
                                           **kwargs)
    else:
        # the first split was imposed: evaluate just that variable, the
        # same way FindBest would
        best = initialVar
        if nBoundsPerVar[best] > 0:
            vTable = [x[best] for x in examples]
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable,
                                                                nBoundsPerVar[best],
                                                                resCodes,
                                                                nPossibleRes)
        elif nBoundsPerVar[best] == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetName('Var: %d' % (best))
    tree.SetData(gainHere)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)
    nextAttrs = attrs[:]
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    def _addBranch(subset):
        # Adds one child for a subset of the examples: a recursively built
        # subtree when there are examples, otherwise a terminal node
        # carrying the most prevalent result of the whole data set.
        if len(subset) != 0:
            tree.AddChildNode(BuildQuantTree(subset, best,
                                             nextAttrs, nPossibleVals,
                                             nBoundsPerVar,
                                             depth=1, maxDepth=maxDepth,
                                             **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % (v), label=v, data=0.0, isTerminal=1)

    if len(qBounds) > 0:
        # Quantized variable: peel off, bound by bound, the examples that
        # fall below each bound.  A single order-preserving pass per bound
        # replaces repeated list.remove() calls (quadratic in the number of
        # examples).
        remaining = examples
        for bound in qBounds:
            below = []
            rest = []
            for ex in remaining:
                if ex[best] < bound:
                    below.append(ex)
                else:
                    rest.append(ex)
            remaining = rest
            _addBranch(below)
        # add the last points remaining
        _addBranch(remaining)
    else:
        # Unquantized variable: one child per possible value
        for val in range(nPossibleVals[best]):
            _addBranch([ex for ex in examples if ex[best] == val])
    return tree
297 298
def TestTree():
    """ testing code for named trees

      Builds an ID3 tree (maxDepth=1) from a small hard-coded data set and
      prints it.  Each example is a name, three variable values and a
      result code.

    """
    data = [
        ['p1', 0, 1, 0, 0],
        ['p2', 0, 0, 0, 1],
        ['p3', 0, 0, 1, 2],
        ['p4', 0, 1, 1, 2],
        ['p5', 1, 0, 0, 2],
        ['p6', 1, 0, 1, 2],
        ['p7', 1, 1, 0, 2],
        ['p8', 1, 1, 1, 0],
    ]
    # skip the name (first column) and the result code (last column)
    nColumns = len(data[0])
    usableVars = range(1, nColumns - 1)
    valCounts = [0, 2, 2, 2, 3]
    demoTree = ID3.ID3Boot(data, usableVars, valCounts, maxDepth=1)
    demoTree.Print()
316 317
def TestQuantTree():
    """ testing code for named trees

      Builds two quantized trees from a small hard-coded data set (one
      without and one with a depth limit of 1); each one is pickled and
      printed.  The third variable is the only quantized one (1 bound).

    """
    data = [
        ['p1', 0, 1, 0.1, 0],
        ['p2', 0, 0, 0.1, 1],
        ['p3', 0, 0, 1.1, 2],
        ['p4', 0, 1, 1.1, 2],
        ['p5', 1, 0, 0.1, 2],
        ['p6', 1, 0, 1.1, 2],
        ['p7', 1, 1, 0.1, 2],
        ['p8', 1, 1, 1.1, 0],
    ]
    # skip the name (first column) and the result code (last column)
    nColumns = len(data[0])
    usableVars = range(1, nColumns - 1)
    valCounts = [0, 2, 2, 0, 3]
    boundCounts = [0, 0, 0, 1, 0]

    print('base')
    fullTree = QuantTreeBoot(data, usableVars, valCounts, boundCounts)
    fullTree.Pickle('regress/QuantTree1.pkl')
    fullTree.Print()

    print('depth limit')
    shallowTree = QuantTreeBoot(data, usableVars, valCounts, boundCounts,
                                maxDepth=1)
    # NOTE(review): same file name as the 'base' tree above, so that pickle
    # gets overwritten -- confirm this is intended
    shallowTree.Pickle('regress/QuantTree1.pkl')
    shallowTree.Print()
344
def TestQuantTree2():
    """ testing code for named trees

      Builds a quantized tree from a small hard-coded data set in which the
      first and third variables are quantized (1 bound each), prints and
      pickles it, then prints every example next to its classification.

    """
    data = [
        ['p1', 0.1, 1, 0.1, 0],
        ['p2', 0.1, 0, 0.1, 1],
        ['p3', 0.1, 0, 1.1, 2],
        ['p4', 0.1, 1, 1.1, 2],
        ['p5', 1.1, 0, 0.1, 2],
        ['p6', 1.1, 0, 1.1, 2],
        ['p7', 1.1, 1, 0.1, 2],
        ['p8', 1.1, 1, 1.1, 0],
    ]
    # skip the name (first column) and the result code (last column)
    nColumns = len(data[0])
    usableVars = range(1, nColumns - 1)
    valCounts = [0, 0, 2, 0, 3]
    boundCounts = [0, 1, 0, 1, 0]

    quantTree = QuantTreeBoot(data, usableVars, valCounts, boundCounts)
    quantTree.Print()
    quantTree.Pickle('regress/QuantTree2.pkl')

    for row in data:
        prediction = quantTree.ClassifyExample(row)
        print('%s %s' % (row, prediction))
if __name__ == "__main__":
    # Run the demos when the module is executed as a script.
    TestTree()
    TestQuantTree()
    # TestQuantTree2() is not run by default.