Package rdkit :: Package ML :: Package DecTree :: Module BuildQuantTree
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.ML.DecTree.BuildQuantTree

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2001-2008  greg Landrum and Rational Discovery LLC 
  4  #  All Rights Reserved 
  5  # 
  6  """ 
  7   
  8  """ 
  9  from __future__ import print_function 
 10  import numpy 
 11  import random 
 12  from rdkit.ML.DecTree import QuantTree, ID3 
 13  from rdkit.ML.InfoTheory import entropy 
 14  from rdkit.ML.Data import Quantize 
 15  from rdkit.six.moves import range 
 16   
 17   
def FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes, nPossibleVals, attrs, exIndices=None,
             **kwargs):
    """ Find the variable (and its quantization bounds) with the largest information gain

      **Arguments**

        - resCodes: the integer result code of each example selected by
          _exIndices_, in the same order

        - examples: a list of lists (nInstances x nVariables+1) of variable
          values + result code

        - nBoundsPerVar: number of quantization bounds for each variable;
          > 0: the variable is quantized with that many bounds,
          0: the variable is treated as discrete (uses _nPossibleVals_),
          < 0: the variable is given a gain of -1e6, so it is effectively skipped

        - nPossibleRes: the number of possible result codes

        - nPossibleVals: the number of possible values of each discrete variable

        - attrs: indices of the variables which may be considered

        - exIndices: (optional) indices into _examples_ of the examples to use;
          defaults to all examples

        - randomDescriptors: (keyword, optional) if > 0 and smaller than
          len(attrs), only that many randomly chosen variables from _attrs_
          are considered

      **Returns**

        a 3-tuple (best, bestGain, bestBounds):
          - best: index of the chosen variable, or -1 if none could be chosen
          - bestGain: the gain of that variable (-1e6 if none was chosen)
          - bestBounds: its quantization bounds (empty for discrete variables)

        Ties in gain are broken in favor of the variable with fewer bounds.
    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    if exIndices is None:
        exIndices = list(range(len(examples)))

    if not len(exIndices):
        return best, bestGain, bestBounds

    nToTake = kwargs.get('randomDescriptors', 0)
    if 0 < nToTake < len(attrs):
        # random.shuffle() lost its ``random=`` argument in Python 3.11, so the
        # old call raised TypeError there. Reproduce the algorithm that
        # shuffle(x, random=random.random) used, so seeded runs keep selecting
        # exactly the same descriptors as before.
        ids = list(range(len(attrs)))
        for i in reversed(range(1, len(ids))):
            j = int(random.random() * (i + 1))
            ids[i], ids[j] = ids[j], ids[i]
        attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            try:
                vTable = [examples[x][var] for x in exIndices]
            except IndexError:
                print('index error retrieving variable: %d' % var)
                raise
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBounds, resCodes, nPossibleRes)
        elif nBounds == 0:
            vTable = ID3.GenVarTable((examples[x] for x in exIndices), nPossibleVals, [var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            # negative bound count: this variable cannot beat the initial bestGain
            gainHere = -1e6
            qBounds = []

        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var
        elif bestGain == gainHere and len(qBounds) < len(bestBounds):
            # equal gain: prefer the simpler split (fewer bounds)
            best = var
            bestBounds = qBounds

    if best == -1:
        # diagnostic output: no variable could be selected
        print('best unaltered')
        print('\tattrs:', attrs)
        print('\tnBounds:', numpy.take(nBoundsPerVar, attrs))
        print('\texamples:')
        for example in (examples[x] for x in exIndices):
            print('\t\t', example)

    return best, bestGain, bestBounds
85 86
def BuildQuantTree(examples, target, attrs, nPossibleVals, nBoundsPerVar, depth=0, maxDepth=-1,
                   exIndices=None, **kwargs):
    """ Recursively grow a quantized decision tree from (a subset of) the examples

      **Arguments**

        - examples: a list of lists (nInstances x nVariables+1) of variable
          values + instance values; the last entry of each row is used as
          the integer result code

        - target: an int. It is not used inside this function; recursive
          calls pass the index of the variable split on at the parent node.

        - attrs: a list of ints indicating which variables can be used in the tree

        - nPossibleVals: a list containing the number of possible values of
          every variable; its last entry is the number of possible result codes

        - nBoundsPerVar: the number of bounds to include for each variable
          (0 means the variable is split on each of its discrete values)

        - depth: (optional) the current depth in the tree

        - maxDepth: (optional) the maximum depth to which the tree
          will be grown; a negative value means no limit

        - exIndices: (optional) indices into _examples_ of the examples
          belonging to this node; defaults to all of them

        - kwargs: passed through to _FindBest_ and to recursive calls.
          If _recycleVars_ is true, a variable used here stays available
          deeper in the tree.

      **Returns**

        a QuantTree.QuantTreeNode with the decision tree

      **NOTE:** This code cannot bootstrap (start from nothing...)
            use _QuantTreeBoot_ (below) for that.
    """
    tree = QuantTree.QuantTreeNode(None, 'node')
    tree.SetData(-666)
    nPossibleRes = nPossibleVals[-1]

    if exIndices is None:
        exIndices = list(range(len(examples)))

    # counts of each result code:
    resCodes = [int(examples[idx][-1]) for idx in exIndices]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1
    nzCounts = numpy.nonzero(counts)[0]

    def _majorityLeaf():
        # Turn this node into a terminal labelled with the most prevalent result.
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    if len(nzCounts) == 1:
        # bottomed out because there is only one result code left
        # with any counts (i.e. there's only one type of example
        # left... this is GOOD!).
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    if len(attrs) == 0 or (maxDepth >= 0 and depth > maxDepth):
        # Bottomed out: no variables left or max depth hit.
        # We don't really know what to do here, so use the heuristic
        # of picking the most prevalent result.
        return _majorityLeaf()

    # find the variable which gives us the largest information gain
    best, _, bestBounds = FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes, nPossibleVals,
                                   attrs, exIndices=exIndices, **kwargs)
    if best == -1:
        # FindBest found no usable variable (e.g. all remaining variables have
        # negative bound counts). Previously this crashed in remove(-1) or built
        # a bogus 'Var: -1' node; treat it like "no variables left" instead.
        return _majorityLeaf()

    # remove that variable from the lists of possible variables
    nextAttrs = attrs[:]
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    # set some info at this node
    tree.SetName('Var: %d' % (best))
    tree.SetLabel(best)
    tree.SetQuantBounds(bestBounds)
    tree.SetTerminal(0)

    def _addChild(childIndices):
        # Recurse on a non-empty subset; an empty subset gets a terminal leaf
        # labelled with this node's most prevalent result.
        if childIndices:
            tree.AddChildNode(
                BuildQuantTree(examples, best, nextAttrs, nPossibleVals, nBoundsPerVar,
                               depth=depth + 1, maxDepth=maxDepth, exIndices=childIndices,
                               **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d' % v, label=v, data=0.0, isTerminal=1)

    if len(bestBounds) > 0:
        # Each bound takes the not-yet-assigned examples whose value is below it;
        # whatever is left after the last bound forms the final child.
        # Partitioning with fresh lists keeps this O(n) per bound (the old
        # list.remove() loop was O(n^2)) while preserving example order.
        remaining = exIndices
        for bound in bestBounds:
            below = [idx for idx in remaining if examples[idx][best] < bound]
            remaining = [idx for idx in remaining if not examples[idx][best] < bound]
            _addChild(below)
        _addChild(remaining)
    else:
        # discrete variable: one child per possible value
        for val in range(nPossibleVals[best]):
            _addChild([idx for idx in exIndices if examples[idx][best] == val])
    return tree
209 210
def QuantTreeBoot(examples, attrs, nPossibleVals, nBoundsPerVar, initialVar=None, maxDepth=-1,
                  **kwargs):
    """ Bootstrapping code for the QuantTree

      If _initialVar_ is not set, the algorithm will automatically
      choose the first variable in the tree (the standard greedy
      approach). Otherwise, _initialVar_ will be used as the first
      split.

      **Arguments**

        - examples: a list of lists (nInstances x nVariables+1) of variable
          values + result code (the last entry of each row)

        - attrs: indices of the variables which may be used; any variable
          whose entry in _nBoundsPerVar_ is -1 is dropped from this list

        - nPossibleVals: the number of possible values of every variable;
          its last entry is the number of possible result codes

        - nBoundsPerVar: the number of bounds to include for each variable
          (0 means the variable is split on its discrete values)

        - initialVar: (optional) the variable to use for the first split

        - maxDepth: (optional) the maximum depth to which the tree will be
          grown; a negative value means no limit

        - kwargs: passed through to _FindBest_ and _BuildQuantTree_
          (e.g. _recycleVars_, _randomDescriptors_)

      **Returns**

        the root QuantTree.QuantTreeNode of the new tree
    """
    # variables flagged with -1 bounds are never split on
    attrs = list(attrs)
    for i in range(len(nBoundsPerVar)):
        if nBoundsPerVar[i] == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None, 'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best, gainHere, qBounds = FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes,
                                           nPossibleVals, attrs, **kwargs)
    else:
        best = initialVar
        if nBoundsPerVar[best] > 0:
            # Build a real list, as FindBest does for the same call: on
            # Python 3 ``map`` returns a lazy one-shot iterator, not a sequence.
            vTable = [x[best] for x in examples]
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBoundsPerVar[best], resCodes,
                                                                nPossibleRes)
        elif nBoundsPerVar[best] == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetName('Var: %d' % (best))
    tree.SetData(gainHere)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)
    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        # NOTE(review): raises ValueError when best is not in attrs, e.g. when
        # FindBest could not pick any variable and returned -1.
        nextAttrs.remove(best)

    def _addChild(childExamples):
        # Grow a subtree (at depth 1) for a non-empty subset of example rows;
        # an empty subset gets a terminal leaf with the overall majority result.
        if childExamples:
            tree.AddChildNode(
                BuildQuantTree(childExamples, best, nextAttrs, nPossibleVals, nBoundsPerVar, depth=1,
                               maxDepth=maxDepth, **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % (v), label=v, data=0.0, isTerminal=1)

    if len(qBounds) > 0:
        # Each bound takes the not-yet-assigned examples below it; the rest form
        # the last child. Fresh lists replace the old O(n^2) list.remove() loop
        # and preserve example order.
        remaining = examples
        for bound in qBounds:
            below = [ex for ex in remaining if ex[best] < bound]
            remaining = [ex for ex in remaining if not ex[best] < bound]
            _addChild(below)
        _addChild(remaining)
    else:
        # discrete variable: one child per possible value
        for val in range(nPossibleVals[best]):
            _addChild([example for example in examples if example[best] == val])
    return tree
302 303
def TestTree():
    """ Testing code for named trees

      Builds a depth-limited (maxDepth=1) ID3 tree from a tiny hand-made
      data set and prints it.
    """
    examples1 = [
        ['p1', 0, 1, 0, 0],
        ['p2', 0, 0, 0, 1],
        ['p3', 0, 0, 1, 2],
        ['p4', 0, 1, 1, 2],
        ['p5', 1, 0, 0, 2],
        ['p6', 1, 0, 1, 2],
        ['p7', 1, 1, 0, 2],
        ['p8', 1, 1, 1, 0],
    ]
    # every column except the name (first) and the result (last) is a variable
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 2, 2, 2, 3]
    ID3.ID3Boot(examples1, attrs, nPossibleVals, maxDepth=1).Print()
314 315
def TestQuantTree():  # pragma: nocover
    """ Testing code for named trees

      Builds a quantized tree twice (unrestricted, then with maxDepth=1),
      pickling and printing each one.

      The created pkl file is required by the unit test code.
    """
    examples1 = [
        ['p1', 0, 1, 0.1, 0],
        ['p2', 0, 0, 0.1, 1],
        ['p3', 0, 0, 1.1, 2],
        ['p4', 0, 1, 1.1, 2],
        ['p5', 1, 0, 0.1, 2],
        ['p6', 1, 0, 1.1, 2],
        ['p7', 1, 1, 0.1, 2],
        ['p8', 1, 1, 1.1, 0],
    ]
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 2, 2, 0, 3]
    boundsPerVar = [0, 0, 0, 1, 0]

    for title, extraArgs in (('base', {}), ('depth limit', {'maxDepth': 1})):
        print(title)
        tree = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar, **extraArgs)
        # NOTE(review): both passes write the same file, so the depth-limited
        # tree overwrites the base one -- confirm this is intended.
        tree.Pickle('test_data/QuantTree1.pkl')
        tree.Print()
def TestQuantTree2():  # pragma: nocover
    """ Testing code for named trees

      Builds a tree with two quantized variables, prints and pickles it,
      then prints the classification of every training example.

      The created pkl file is required by the unit test code.
    """
    examples1 = [
        ['p1', 0.1, 1, 0.1, 0],
        ['p2', 0.1, 0, 0.1, 1],
        ['p3', 0.1, 0, 1.1, 2],
        ['p4', 0.1, 1, 1.1, 2],
        ['p5', 1.1, 0, 0.1, 2],
        ['p6', 1.1, 0, 1.1, 2],
        ['p7', 1.1, 1, 0.1, 2],
        ['p8', 1.1, 1, 1.1, 0],
    ]
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 0, 2, 0, 3]
    boundsPerVar = [0, 1, 0, 1, 0]

    tree = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    tree.Print()
    tree.Pickle('test_data/QuantTree2.pkl')

    for row in examples1:
        print(row, tree.ClassifyExample(row))


if __name__ == "__main__":  # pragma: nocover
    TestTree()
    TestQuantTree()
    # TestQuantTree2()