1
2
3
4
5
6 """
7
8 """
9 from __future__ import print_function
10 import numpy
11 import random
12 from rdkit.ML.DecTree import QuantTree, ID3
13 from rdkit.ML.InfoTheory import entropy
14 from rdkit.ML.Data import Quantize
15 from rdkit.six.moves import range
16
17
def FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes, nPossibleVals, attrs, exIndices=None,
             **kwargs):
    """ Finds the candidate variable with the largest information gain

    **Arguments**

      - resCodes: the result codes of the examples being considered, in the
        same order as _exIndices_ (they are passed, together with the values
        gathered via _exIndices_, to the quantization code)

      - examples: an indexable collection of examples; _examples[i][var]_ is
        the value of variable _var_ for example _i_

      - nBoundsPerVar: for each variable, the number of quantization bounds
        to look for.  0 means the variable is used without quantization;
        a negative value gives the variable a gain of -1e6, so it is
        effectively never selected.

      - nPossibleRes: the number of possible result codes

      - nPossibleVals: the number of possible values of each variable
        (only consulted for variables with 0 bounds)

      - attrs: the variables that are candidates for the split

      - exIndices: (optional) indices into _examples_ of the examples to
        use.  Defaults to all examples.

      - randomDescriptors: (optional keyword) if > 0 and smaller than the
        number of candidates, only that many randomly selected candidates
        are evaluated.

    **Returns**

      a 3-tuple:

        1) the best variable (-1 if none was selected)

        2) the information gain of that variable (-1e6 if none was selected)

        3) the quantization bounds of that variable (an empty list for
           unquantized variables)

    **Notes**

      - when two variables have exactly the same gain, the one with fewer
        quantization bounds wins.

      - if no variable is selected, diagnostic information is printed.

    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    if exIndices is None:
        exIndices = list(range(len(examples)))

    # nothing to evaluate: report "no variable selected"
    if not len(exIndices):
        return best, bestGain, bestBounds

    nToTake = kwargs.get('randomDescriptors', 0)
    if nToTake > 0:
        nAttrs = len(attrs)
        if nToTake < nAttrs:
            ids = list(range(nAttrs))
            # the second argument of random.shuffle() was removed in Python 3.11;
            # the default behavior already draws from the module-level generator.
            random.shuffle(ids)
            attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            # quantized variable: search for the bounds giving the best gain
            try:
                vTable = [examples[x][var] for x in exIndices]
            except IndexError:
                print('index error retrieving variable: %d' % var)
                raise
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBounds, resCodes,
                                                                nPossibleRes)
        elif nBounds == 0:
            # unquantized variable: gain comes straight from the variable table
            vTable = ID3.GenVarTable((examples[x] for x in exIndices), nPossibleVals, [var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var
        elif bestGain == gainHere:
            # tie: prefer the variable that needs fewer bounds
            if len(qBounds) < len(bestBounds):
                best = var
                bestBounds = qBounds

    if best == -1:
        print('best unaltered')
        print('\tattrs:', attrs)
        print('\tnBounds:', numpy.take(nBoundsPerVar, attrs))
        print('\texamples:')
        for example in (examples[x] for x in exIndices):
            print('\t\t', example)

    return best, bestGain, bestBounds
85
86
def BuildQuantTree(examples, target, attrs, nPossibleVals, nBoundsPerVar, depth=0, maxDepth=-1,
                   exIndices=None, **kwargs):
    """ Recursively builds a decision tree that may use quantized variables

    **Arguments**

      - examples: a list of lists (nInstances x nVariables+1) of variable
        values + instance values

      - target: an int.  It is not used by this function; recursive calls
        pass the variable that was just split on.

      - attrs: a list of ints indicating which variables can be used in the tree

      - nPossibleVals: a list containing the number of possible values of
        every variable.  The last entry is the number of possible result codes.

      - nBoundsPerVar: the number of bounds to include for each variable

      - depth: (optional) the current depth in the tree

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown.  Negative values disable the limit.

      - exIndices: (optional) indices into _examples_ of the examples this
        node is built from.  Defaults to all examples.

      - recycleVars: (optional keyword) if true, the variable chosen at this
        node remains available to the subtrees.

      - any other keyword arguments are passed on to _FindBest_ and to the
        recursive calls.

    **Returns**

      a QuantTree.QuantTreeNode with the decision tree

    **NOTE:** This code cannot bootstrap (start from nothing...)
      use _QuantTreeBoot_ (below) for that.
    """
    tree = QuantTree.QuantTreeNode(None, 'node')
    # the gain returned by FindBest is discarded below, so every node built
    # here carries this fixed data value
    tree.SetData(-666)
    nPossibleRes = nPossibleVals[-1]

    if exIndices is None:
        exIndices = list(range(len(examples)))

    # counts of each result code among the examples at this node
    resCodes = [int(examples[idx][-1]) for idx in exIndices]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1
    nzCounts = numpy.nonzero(counts)[0]

    if len(nzCounts) == 1:
        # only one result code is left: this is a pure leaf
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    if len(attrs) == 0 or (maxDepth >= 0 and depth > maxDepth):
        # no variables left or depth limit exceeded: fall back to the most
        # common result code
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    # find the variable which gives the largest information gain
    best, _, bestBounds = FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes, nPossibleVals,
                                   attrs, exIndices=exIndices, **kwargs)

    # remove that variable from the candidates for the subtrees
    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    tree.SetName('Var: %d' % (best))
    tree.SetLabel(best)
    tree.SetQuantBounds(bestBounds)
    tree.SetTerminal(0)

    def _addChildFor(subset):
        # Adds the child for one branch: a majority-vote leaf if no examples
        # fall into the branch, a recursively built subtree otherwise.
        if len(subset) == 0:
            v = numpy.argmax(counts)
            tree.AddChild('%d' % v, label=v, data=0.0, isTerminal=1)
        else:
            tree.AddChildNode(
                BuildQuantTree(examples, best, nextAttrs, nPossibleVals, nBoundsPerVar,
                               depth=depth + 1, maxDepth=maxDepth, exIndices=subset, **kwargs))

    if len(bestBounds) > 0:
        # quantized variable: one branch per bound (values below the bound
        # that were not claimed by an earlier bound), plus a final branch
        # for everything left over.  Each pass is a single linear partition.
        remaining = list(exIndices)
        for bound in bestBounds:
            below = []
            notBelow = []
            for index in remaining:
                if examples[index][best] < bound:
                    below.append(index)
                else:
                    notBelow.append(index)
            remaining = notBelow
            _addChildFor(below)
        _addChildFor(remaining)
    else:
        # unquantized variable: one branch per possible value
        for val in range(nPossibleVals[best]):
            _addChildFor([idx for idx in exIndices if examples[idx][best] == val])
    return tree
209
210
def QuantTreeBoot(examples, attrs, nPossibleVals, nBoundsPerVar, initialVar=None, maxDepth=-1,
                  **kwargs):
    """ Bootstrapping code for the QuantTree

    If _initialVar_ is not set, the algorithm will automatically
    choose the first variable in the tree (the standard greedy
    approach).  Otherwise, _initialVar_ will be used as the first
    split.

    **Arguments**

      - examples: a list of lists of variable values + instance values;
        the last entry of each example is its result code

      - attrs: the variables which can be used in the tree.  Variables whose
        entry in _nBoundsPerVar_ is -1 are dropped (the caller's sequence is
        not modified).

      - nPossibleVals: a list containing the number of possible values of
        every variable.  The last entry is the number of possible result codes.

      - nBoundsPerVar: the number of bounds to include for each variable

      - initialVar: (optional) the variable to use for the first split

      - maxDepth: (optional) passed on to _BuildQuantTree_ for the subtrees

      - recycleVars: (optional keyword) if true, the variable used for the
        first split remains available to the subtrees.

      - any other keyword arguments are passed on to _FindBest_ and
        _BuildQuantTree_.

    **Returns**

      a QuantTree.QuantTreeNode with the decision tree

    """
    attrs = list(attrs)
    for i, nBounds in enumerate(nBoundsPerVar):
        if nBounds == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None, 'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best, gainHere, qBounds = FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes,
                                           nPossibleVals, attrs, **kwargs)
    else:
        best = initialVar
        if nBoundsPerVar[best] > 0:
            # build a real list: on Python 3 map() returns a one-shot iterator
            # without a length, which is not a usable table of values
            vTable = [x[best] for x in examples]
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBoundsPerVar[best],
                                                                resCodes, nPossibleRes)
        elif nBoundsPerVar[best] == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetName('Var: %d' % (best))
    tree.SetData(gainHere)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)
    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    def _addChildFor(subset):
        # Adds the child for one branch: a recursively built subtree if any
        # examples fall into the branch, a majority-vote leaf otherwise.
        if len(subset) != 0:
            tree.AddChildNode(
                BuildQuantTree(subset, best, nextAttrs, nPossibleVals, nBoundsPerVar, depth=1,
                               maxDepth=maxDepth, **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % (v), label=v, data=0.0, isTerminal=1)

    if len(qBounds) > 0:
        # quantized variable: one branch per bound (values below the bound
        # that were not claimed by an earlier bound), plus a final branch
        # for everything left over.  Each pass is a single linear partition.
        remaining = [examples[index] for index in range(len(examples))]
        for bound in qBounds:
            below = []
            notBelow = []
            for ex in remaining:
                if ex[best] < bound:
                    below.append(ex)
                else:
                    notBelow.append(ex)
            remaining = notBelow
            _addChildFor(below)
        _addChildFor(remaining)
    else:
        # unquantized variable: one branch per possible value
        for val in range(nPossibleVals[best]):
            _addChildFor([example for example in examples if example[best] == val])
    return tree
302
303
def TestTree():
    """ testing code for named trees

    Builds an ID3 tree (depth limit 1) from a small set of named examples
    and prints it.

    """
    # each example is [name, var1, var2, var3, result code]
    examples1 = [['p1', 0, 1, 0, 0], ['p2', 0, 0, 0, 1], ['p3', 0, 0, 1, 2], ['p4', 0, 1, 1, 2],
                 ['p5', 1, 0, 0, 2], ['p6', 1, 0, 1, 2], ['p7', 1, 1, 0, 2], ['p8', 1, 1, 1, 0]]
    # skip the name (first entry) and the result code (last entry)
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 2, 2, 2, 3]
    t1 = ID3.ID3Boot(examples1, attrs, nPossibleVals, maxDepth=1)
    t1.Print()
314
315
def TestQuantTree():
    """ Testing code for named trees

    The created pkl file is required by the unit test code.
    """
    # each example is [name, var1, var2, var3, result code]; var3 is continuous
    examples1 = [['p1', 0, 1, 0.1, 0], ['p2', 0, 0, 0.1, 1], ['p3', 0, 0, 1.1, 2],
                 ['p4', 0, 1, 1.1, 2], ['p5', 1, 0, 0.1, 2], ['p6', 1, 0, 1.1, 2],
                 ['p7', 1, 1, 0.1, 2], ['p8', 1, 1, 1.1, 0]]
    # skip the name (first entry) and the result code (last entry)
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 2, 2, 0, 3]
    # one quantization bound for var3, none for the other variables
    boundsPerVar = [0, 0, 0, 1, 0]

    print('base')
    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Pickle('test_data/QuantTree1.pkl')
    t1.Print()

    print('depth limit')
    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar, maxDepth=1)
    # NOTE(review): this writes to the same path as the 'base' tree above, so
    # only the depth-limited tree survives on disk -- confirm this is intended.
    t1.Pickle('test_data/QuantTree1.pkl')
    t1.Print()
337
338
# NOTE(review): the function name is inferred from the QuantTree2.pkl file it
# writes; it is not called from the __main__ block -- confirm the name.
def TestQuantTree2():
    """ testing code for named trees

    The created pkl file is required by the unit test code.
    """
    # each example is [name, var1, var2, var3, result code]; var1 and var3
    # are continuous
    examples1 = [['p1', 0.1, 1, 0.1, 0], ['p2', 0.1, 0, 0.1, 1], ['p3', 0.1, 0, 1.1, 2],
                 ['p4', 0.1, 1, 1.1, 2], ['p5', 1.1, 0, 0.1, 2], ['p6', 1.1, 0, 1.1, 2],
                 ['p7', 1.1, 1, 0.1, 2], ['p8', 1.1, 1, 1.1, 0]]
    # skip the name (first entry) and the result code (last entry)
    attrs = list(range(1, len(examples1[0]) - 1))
    nPossibleVals = [0, 0, 2, 0, 3]
    # one quantization bound each for var1 and var3
    boundsPerVar = [0, 1, 0, 1, 0]

    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Print()
    t1.Pickle('test_data/QuantTree2.pkl')

    for example in examples1:
        print(example, t1.ClassifyExample(example))
357
358
if __name__ == "__main__":
    # when run as a script, execute the demo routines in order
    for _demo in (TestTree, TestQuantTree):
        _demo()
362
363