Package ML :: Package DecTree :: Module CrossValidate
[hide private]
[frames | no frames]

Source Code for Module ML.DecTree.CrossValidate

  1  # 
  2  #  Copyright (C) 2000  greg Landrum 
  3  # 
  4  """ handles doing cross validation with decision trees 
  5   
  6  This is, perhaps, a little misleading.  For the purposes of this module, 
  7  cross validation == evaluating the accuracy of a tree. 
  8   
  9   
 10  """ 
 11  from ML.DecTree import ID3 
 12  from ML.Data import SplitData 
 13  import numpy 
 14  import RDRandom 
 15   
def ChooseOptimalRoot(examples,trainExamples,testExamples,attrs,
                      nPossibleVals,treeBuilder,nQuantBounds=[],
                      **kwargs):
    """ loops through all possible tree roots and chooses the one which produces the best tree

    **Arguments**

      - examples: the full set of examples

      - trainExamples: the training examples

      - testExamples: the testing examples (accepted but not used here)

      - attrs: a list of attributes to consider in the tree building

      - nPossibleVals: a list of the number of possible values each variable can adopt

      - treeBuilder: the function to be used to actually build the tree

      - nQuantBounds: an optional list.  If present, it's assumed that the builder
        algorithm takes this argument as well (for building QuantTrees)

      - kwargs: any other keyword arguments are passed on to _treeBuilder_

    **Returns**

      The best tree found.  If no attributes are left to try, None is returned;
      if the builder never produces a usable tree, the (false) result of the
      first build is returned.

    **Notes**

      1) Trees are built using _trainExamples_

      2) Testing of each tree (to determine which is best) is done using _CrossValidate_ and
         the entire set of data (i.e. all of _examples_)

      3) _testExamples_ is not used at all, which immediately raises the question of
         why it's even being passed in

      4) Attributes whose entry in _nQuantBounds_ is -1 are never considered.  The
         caller's _attrs_ list is not modified.

      5) Every remaining attribute, including the first one, is tried as the root.

    """
    # work on a copy so that the caller's list is left alone
    attrs = attrs[:]
    if nQuantBounds:
        for idx,bound in enumerate(nQuantBounds):
            if bound==-1 and idx in attrs:
                attrs.remove(idx)
    if not attrs:
        # nothing to build a tree from
        return None

    # the positional arguments are the same for every candidate root
    if nQuantBounds is None or nQuantBounds == []:
        builderArgs = (trainExamples,attrs,nPossibleVals)
    else:
        builderArgs = (trainExamples,attrs,nPossibleVals,nQuantBounds)

    # error assigned to candidates for which no tree could be built
    noTreeErr = 1e6
    trees = []
    errs = []
    for attr in attrs:
        # built as a dict so that an 'initialVar' supplied by the caller in
        # kwargs takes precedence instead of causing a duplicate-keyword error
        argD = {'initialVar':attr}
        argD.update(kwargs)
        tree = treeBuilder(*builderArgs,**argD)
        if tree:
            err,_ = CrossValidate(tree,examples,appendExamples=0)
        else:
            err = noTreeErr
        trees.append(tree)
        errs.append(err)
    # argmin picks the first candidate in case of a tie
    return trees[numpy.argmin(errs)]
79
def CrossValidate(tree,testExamples,appendExamples=0):
    """ Determines the classification error for the testExamples

    **Arguments**

      - tree: a decision tree (or anything supporting a _ClassifyExample()_ method)

      - testExamples: a list of examples to be used for testing.  The last
        element of each example is taken to be its true result.

      - appendExamples: a toggle which is passed along to the tree as it does
        the classification.  The trees can use this to store the examples they
        classify locally.

    **Returns**

      a 2-tuple consisting of:

        1) the error of the tree, as the fraction (between 0 and 1) of the
           examples which were misclassified

        2) a list of misclassified examples

    **Notes**

      - a ZeroDivisionError is raised if _testExamples_ is empty

    """
    nTest = len(testExamples)
    badExamples = []
    for testEx in testExamples:
        trueRes = testEx[-1]
        res = tree.ClassifyExample(testEx,appendExamples)
        # numpy.any() accepts the plain bool produced by comparing ordinary
        # python values as well as the array produced by comparing arrays
        if numpy.any(trueRes != res):
            badExamples.append(testEx)
    return float(len(badExamples))/nTest,badExamples
115
def CrossValidationDriver(examples,attrs,nPossibleVals,holdOutFrac=.3,silent=0,
                          calcTotalError=0,treeBuilder=ID3.ID3Boot,lessGreedy=0,
                          startAt=None,
                          nQuantBounds=[],
                          maxDepth=-1,
                          **kwargs):
    """ Driver function for building trees and doing cross validation

    **Arguments**

      - examples: the full set of examples

      - attrs: a list of attributes to consider in the tree building

      - nPossibleVals: a list of the number of possible values each variable can adopt

      - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
        (used to calculate the error)

      - silent: a toggle used to control how much visual noise this makes as it goes.

      - calcTotalError: a toggle used to indicate whether the classification error
        of the tree should be calculated using the entire data set (when true) or just
        the hold-out set (when false)

      - treeBuilder: the function to call to build the tree

      - lessGreedy: toggles use of the less greedy tree growth algorithm (see
        _ChooseOptimalRoot_).

      - startAt: forces the tree to be rooted at this descriptor.  This is not
        passed along when _lessGreedy_ is set.

      - nQuantBounds: an optional list.  If present, it's assumed that the builder
        algorithm takes this argument as well (for building QuantTrees)

      - maxDepth: an optional integer.  If present, it's assumed that the builder
        algorithm takes this argument as well

      - kwargs: any other keyword arguments are passed on to the tree builder.
        If a true value is provided for 'replacementSelection', the data are
        split using selection with replacement.

    **Returns**

      a 2-tuple containing:

        1) the tree

        2) the cross-validation error of the tree (a fraction, not a percentage)

    **Notes**

      - before it is returned the tree is given the misclassified examples
        (_SetBadExamples()_), the training examples (_SetTrainingExamples()_),
        the test examples (_SetTestExamples()_) and the indices of the
        training examples (as _tree._trainIndices_)

    """
    nTot = len(examples)
    # the two ways of splitting the data differ only in these two flags
    if kwargs.get('replacementSelection',0):
        legacy,replacement = 0,1
    else:
        legacy,replacement = 1,0
    testIndices,trainIndices = SplitData.SplitIndices(nTot,holdOutFrac,
                                                      silent=1,legacy=legacy,
                                                      replacement=replacement)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    nTrain = len(trainExamples)
    if not silent:
        print('Training with %d examples'%(nTrain))

    if lessGreedy:
        tree = ChooseOptimalRoot(examples,trainExamples,testExamples,
                                 attrs,nPossibleVals,treeBuilder,nQuantBounds,
                                 maxDepth=maxDepth,**kwargs)
    elif nQuantBounds is None or nQuantBounds == []:
        tree = treeBuilder(trainExamples,attrs,nPossibleVals,
                           initialVar=startAt,maxDepth=maxDepth,**kwargs)
    else:
        tree = treeBuilder(trainExamples,attrs,nPossibleVals,nQuantBounds,
                           initialVar=startAt,maxDepth=maxDepth,**kwargs)

    nTest = len(testExamples)
    if not silent:
        print('Testing with %d examples'%nTest)
    if calcTotalError:
        xValError,badExamples = CrossValidate(tree,examples,appendExamples=0)
    else:
        xValError,badExamples = CrossValidate(tree,testExamples,appendExamples=1)
    if not silent:
        print('Validation error was %%%4.2f'%(100*xValError))

    # store the details of this run on the tree itself
    tree.SetBadExamples(badExamples)
    tree.SetTrainingExamples(trainExamples)
    tree.SetTestExamples(testExamples)
    tree._trainIndices = trainIndices
    return tree,xValError
205 206
def TestRun():
    """ testing code

    Builds a tree from 200 random examples, pickles it to 'save.pkl' and then
    prints whether or not a deep copy of the tree compares equal to (and can
    be found in a list containing) the original.

    """
    from ML.DecTree import randomtest
    examples,attrs,nPossibleVals = randomtest.GenRandomExamples(nExamples = 200)
    tree,xValError = CrossValidationDriver(examples,attrs,
                                           nPossibleVals)

    tree.Pickle('save.pkl')

    import copy
    t2 = copy.deepcopy(tree)
    # the pieces are joined by hand so that the output is the same whether
    # print is a statement or a function
    print(' '.join(('t1 == t2',str(tree==t2))))
    treeList = [tree]
    print(' '.join(('t2 in [tree]',str(t2 in treeList),str(treeList.index(t2)))))
# run the test code when this module is executed as a script
if __name__ == '__main__':
    TestRun()