Package rdkit :: Package ML :: Package DecTree :: Module CrossValidate
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.ML.DecTree.CrossValidate

  1  # 
  2  #  Copyright (C) 2000  greg Landrum 
  3  # 
  4  """ handles doing cross validation with decision trees 
  5   
  6  This is, perhaps, a little misleading.  For the purposes of this module, 
  7  cross validation == evaluating the accuracy of a tree. 
  8   
  9   
 10  """ 
 11  from __future__ import print_function 
 12   
 13  import numpy 
 14   
 15  from rdkit.ML.Data import SplitData 
 16  from rdkit.ML.DecTree import ID3 
 17  from rdkit.ML.DecTree import randomtest 
 18   
 19   
def ChooseOptimalRoot(examples, trainExamples, testExamples, attrs, nPossibleVals, treeBuilder,
                      nQuantBounds=None, **kwargs):
  """ loops through all possible tree roots and chooses the one which produces the best tree

  **Arguments**

    - examples: the full set of examples

    - trainExamples: the training examples

    - testExamples: the testing examples

    - attrs: a list of attributes to consider in the tree building

    - nPossibleVals: a list of the number of possible values each variable can adopt

    - treeBuilder: the function to be used to actually build the tree

    - nQuantBounds: an optional list.  If present, it's assumed that the builder
      algorithm takes this argument as well (for building QuantTrees)

  **Returns**

    The best tree found

  **Notes**

    1) Trees are built using _trainExamples_

    2) Testing of each tree (to determine which is best) is done using _CrossValidate_ and
       the entire set of data (i.e. all of _examples_)

    3) _testExamples_ is not used at all, which immediately raises the question of
       why it's even being passed in

  """
  # The mutable default argument ([]) was replaced with None to avoid sharing a
  # single list object across calls; None is normalized here, so behavior is
  # unchanged for every caller.
  if nQuantBounds is None:
    nQuantBounds = []
  # work on a copy so the caller's attrs list is not mutated
  attrs = attrs[:]
  # attributes flagged with a -1 quant bound cannot be used as split variables
  if nQuantBounds:
    for i in range(len(nQuantBounds)):
      if nQuantBounds[i] == -1 and i in attrs:
        attrs.remove(i)
  nAttrs = len(attrs)
  trees = [None] * nAttrs
  errs = [0] * nAttrs
  errs[0] = 1e6

  # NOTE(review): the loop starts at 1, so attrs[0] is never tried as a root and
  # slot 0 keeps its sentinel error of 1e6.  This matches the historical
  # behavior but looks suspicious -- confirm whether attrs[0] should be included.
  for i in range(1, nAttrs):
    argD = {'initialVar': attrs[i]}
    argD.update(kwargs)
    if not nQuantBounds:
      trees[i] = treeBuilder(trainExamples, attrs, nPossibleVals, **argD)
    else:
      # QuantTree builders take the bounds list as an extra positional argument
      trees[i] = treeBuilder(trainExamples, attrs, nPossibleVals, nQuantBounds, **argD)
    if trees[i]:
      errs[i], _ = CrossValidate(trees[i], examples, appendExamples=0)
    else:
      # builder failed for this root; make sure it can never be selected
      errs[i] = 1e6
  best = numpy.argmin(errs)
  return trees[best]
80 81
def CrossValidate(tree, testExamples, appendExamples=0):
  """ Determines the classification error for the testExamples

  **Arguments**

    - tree: a decision tree (or anything supporting a _ClassifyExample()_ method)

    - testExamples: a list of examples to be used for testing

    - appendExamples: a toggle which is passed along to the tree as it does
      the classification.  The trees can use this to store the examples they
      classify locally.

  **Returns**

    a 2-tuple consisting of:

      1) the percent error of the tree

      2) a list of misclassified examples

  """
  nTest = len(testExamples)
  nBad = 0
  badExamples = []
  for testEx in testExamples:
    # by convention the true classification is the last entry of the example
    trueRes = testEx[-1]
    res = tree.ClassifyExample(testEx, appendExamples)
    # numpy.any() works for both scalar and array-valued labels; the previous
    # `(trueRes != res).any()` raised AttributeError for plain scalar labels
    # because a bare bool has no .any() method.
    if numpy.any(trueRes != res):
      badExamples.append(testEx)
      nBad += 1

  return float(nBad) / nTest, badExamples
116 117
def CrossValidationDriver(examples, attrs, nPossibleVals, holdOutFrac=.3, silent=0,
                          calcTotalError=0, treeBuilder=ID3.ID3Boot, lessGreedy=0, startAt=None,
                          nQuantBounds=None, maxDepth=-1, **kwargs):
  """ Driver function for building trees and doing cross validation

  **Arguments**

    - examples: the full set of examples

    - attrs: a list of attributes to consider in the tree building

    - nPossibleVals: a list of the number of possible values each variable can adopt

    - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
      (used to calculate the error)

    - silent: a toggle used to control how much visual noise this makes as it goes.

    - calcTotalError: a toggle used to indicate whether the classification error
      of the tree should be calculated using the entire data set (when true) or just
      the training hold out set (when false)

    - treeBuilder: the function to call to build the tree

    - lessGreedy: toggles use of the less greedy tree growth algorithm (see
      _ChooseOptimalRoot_).

    - startAt: forces the tree to be rooted at this descriptor

    - nQuantBounds: an optional list.  If present, it's assumed that the builder
      algorithm takes this argument as well (for building QuantTrees)

    - maxDepth: an optional integer.  If present, it's assumed that the builder
      algorithm takes this argument as well

  **Returns**

     a 2-tuple containing:

       1) the tree

       2) the cross-validation error of the tree

  """
  # The mutable default argument ([]) was replaced with None to avoid sharing a
  # single list object across calls; the code below already treats None and []
  # identically, so this is fully backward compatible.
  if nQuantBounds is None:
    nQuantBounds = []
  nTot = len(examples)
  # split the data into hold-out (test) and training sets, either without
  # replacement (the legacy default) or with replacement (bootstrap-style)
  if not kwargs.get('replacementSelection', 0):
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=1,
                                                       replacement=0)
  else:
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=0,
                                                       replacement=1)
  trainExamples = [examples[x] for x in trainIndices]
  testExamples = [examples[x] for x in testIndices]

  nTrain = len(trainExamples)
  if not silent:
    print('Training with %d examples' % (nTrain))

  if not lessGreedy:
    if not nQuantBounds:
      tree = treeBuilder(trainExamples, attrs, nPossibleVals, initialVar=startAt, maxDepth=maxDepth,
                         **kwargs)
    else:
      tree = treeBuilder(trainExamples, attrs, nPossibleVals, nQuantBounds, initialVar=startAt,
                         maxDepth=maxDepth, **kwargs)
  else:
    tree = ChooseOptimalRoot(examples, trainExamples, testExamples, attrs, nPossibleVals,
                             treeBuilder, nQuantBounds, maxDepth=maxDepth, **kwargs)

  nTest = len(testExamples)
  if not silent:
    print('Testing with %d examples' % nTest)
  if not calcTotalError:
    xValError, badExamples = CrossValidate(tree, testExamples, appendExamples=1)
  else:
    xValError, badExamples = CrossValidate(tree, examples, appendExamples=0)
  if not silent:
    print('Validation error was %%%4.2f' % (100 * xValError))
  # stash the split and the misclassifications on the tree for later inspection
  tree.SetBadExamples(badExamples)
  tree.SetTrainingExamples(trainExamples)
  tree.SetTestExamples(testExamples)
  tree._trainIndices = trainIndices
  return tree, xValError
201 202
def TestRun():
  """ testing code """
  # build a random data set and fit a tree to it
  examples, attrs, nPossibleVals = randomtest.GenRandomExamples(nExamples=200)
  tree, _ = CrossValidationDriver(examples, attrs, nPossibleVals)

  tree.Pickle('save.pkl')

  # sanity-check equality and list membership against a deep copy
  import copy
  clone = copy.deepcopy(tree)
  print('t1 == t2', tree == clone)
  trees = [tree]
  print('t2 in [tree]', clone in trees, trees.index(clone))


if __name__ == '__main__':  # pragma: nocover
  TestRun()