1
2
3
4
5
6 """
7
8 """
9 from __future__ import print_function
10
11 import copy
12 import random
13
14 import numpy
15
16 from rdkit.DataStructs.VectCollection import VectCollection
17 from rdkit.ML import InfoTheory
18 from rdkit.ML.DecTree import SigTree
19
20 try:
21 from rdkit.ML.FeatureSelect import CMIM
22 except ImportError:
23 CMIM = None
24
25
27 """ Generates a random subset of a group of indices
28
29 **Arguments**
30
31 - nToInclude: the size of the desired set
32
33 - nBits: the maximum index to be included in the set
34
35 **Returns**
36
37 a list of indices
38
39 """
40
41
42 return random.sample(range(nBits), nToInclude)
43
44
def BuildSigTree(examples, nPossibleRes, ensemble=None, random=0,
                 metric=InfoTheory.InfoType.BIASENTROPY, biasList=[1], depth=0, maxDepth=-1,
                 useCMIM=0, allowCollections=False, verbose=0, **kwargs):
    """ Recursively builds a decision tree that splits on fingerprint bits

    **Arguments**

      - examples: the examples to be classified.  Each example
        should be a sequence at least three entries long, with
        entry 0 being a label, entry 1 a BitVector and entry -1
        an activity value

      - nPossibleRes: the number of result codes possible

      - ensemble: (optional) if this argument is provided, it
        should be a sequence which is used to limit the bits
        which are actually considered as potential descriptors.
        The default is None (use all bits).

      - random: (optional) If this argument is nonzero, it
        specifies the number of bits to be randomly selected
        for consideration at this node (i.e. this toggles the
        growth of Random Trees).
        The default is 0 (no random descriptor selection)

      - metric: (optional) This is an _InfoTheory.InfoType_ and
        sets the metric used to rank the bits.
        The default is _InfoTheory.InfoType.BIASENTROPY_

      - biasList: (optional) If provided, this provides a bias
        list for the bit ranker.
        See the _InfoTheory.InfoBitRanker_ docs for an explanation
        of bias.
        The default value is [1], which biases towards actives.

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown.  A node whose depth exceeds maxDepth is made
        terminal and labelled with the most common result code.
        The default is -1 (no depth limit).

      - useCMIM: (optional) if this is >0, the CMIM algorithm
        (conditional mutual information maximization) will be
        used to select the descriptors used to build the trees.
        The value of the variable should be set to the number
        of descriptors to be used.  This option and the
        ensemble option are mutually exclusive (CMIM will not be
        used if the ensemble is set), but it happily coexists
        with the random argument (to only consider random subsets
        of the top N CMIM bits)
        The default is 0 (do not use CMIM)

      - allowCollections: (optional) if true and the examples carry
        VectCollection fingerprints, every example that has the
        chosen bit set gets a shallow copy of its collection with
        the vectors not matching that bit detached before it is
        handed to the "on" subtree.
        The default is False.

      - verbose: (optional) if nonzero, progress is printed.

      - depth: (optional) the current depth in the tree
        This is used in the recursion and should not be set
        by the client.

      - kwargs: accepted and ignored.

    **Returns**

      a SigTree.SigTreeNode with the root of the decision tree

    """
    # NOTE: the mutable default for biasList is kept for interface
    # compatibility; it is never modified in this function.
    if verbose:
        print(' ' * depth, 'Build')
    tree = SigTree.SigTreeNode(None, 'node', level=depth)
    tree.SetData(-666)

    # count how often each result code occurs among the examples
    counts = [0] * nPossibleRes
    for example in examples:
        counts[int(example[-1])] += 1

    nzCounts = numpy.nonzero(counts)[0]
    if verbose:
        print(' ' * depth, '\tcounts:', counts)

    if len(nzCounts) == 1:
        # only one result code is left, so this node is a pure leaf
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    if maxDepth >= 0 and depth > maxDepth:
        # depth limit exceeded: fall back to the most prevalent result
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    # find the bit which gives the best improvement using an InfoBitRanker
    fp = examples[0][1]
    nBits = fp.GetNumBits()
    ranker = InfoTheory.InfoBitRanker(nBits, nPossibleRes, metric)
    if biasList:
        ranker.SetBiasList(biasList)
    if CMIM is not None and useCMIM > 0 and not ensemble:
        ensemble = CMIM.SelectFeatures(examples, useCMIM, bvCol=1)

    # NOTE(review): the ensemble only restricts the ranker when random is
    # nonzero; with random == 0 no mask is set and all bits are considered,
    # which looks at odds with the ensemble/useCMIM docs -- confirm intent.
    if random:
        if ensemble:
            if len(ensemble) > random:
                picks = _GenerateRandomEnsemble(random, len(ensemble))
                availBits = list(numpy.take(ensemble, picks))
            else:
                # the ensemble is no bigger than the requested sample, so
                # use every ensemble bit (the bit ids themselves, not
                # their positions within the ensemble)
                availBits = list(ensemble)
        else:
            availBits = _GenerateRandomEnsemble(random, nBits)
    else:
        availBits = None
    if availBits:
        ranker.SetMaskBits(availBits)

    useCollections = isinstance(examples[0][1], VectCollection)
    for example in examples:
        if not useCollections:
            ranker.AccumulateVotes(example[1], example[-1])
        else:
            example[1].Reset()
            ranker.AccumulateVotes(example[1].orVect, example[-1])

    try:
        bitInfo = ranker.GetTopN(1)[0]
        best = int(bitInfo[0])
        gain = bitInfo[1]
    except Exception:
        # best effort: report the failure and treat it as "no useful bit"
        import traceback
        traceback.print_exc()
        print('get top n failed')
        gain = -1.0

    if gain <= 0.0:
        # no bit improves things; label with the most prevalent result
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('?%d?' % v)
        tree.SetTerminal(1)
        return tree

    if verbose:
        print(' ' * depth, '\tbest:', bitInfo)

    # record the chosen bit at this node
    tree.SetName('Bit-%d' % (best))
    tree.SetLabel(best)
    tree.SetTerminal(0)

    # partition the examples by the value of the chosen bit
    onExamples = []
    offExamples = []
    for example in examples:
        if example[1][best]:
            if allowCollections and useCollections:
                sig = copy.copy(example[1])
                sig.DetachVectsNotMatchingBit(best)
                ex = [example[0], sig]
                if len(example) > 2:
                    ex.extend(example[2:])
                example = ex
            onExamples.append(example)
        else:
            offExamples.append(example)

    # build one subtree per bit value: "off" child first, then "on" child
    # NOTE(review): useCMIM, allowCollections and kwargs are not forwarded
    # to the recursive call, so children run with their defaults -- confirm
    # whether allowCollections is meant to apply below the root.
    for subset in (offExamples, onExamples):
        if not subset:
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % v, label=v, data=0.0, isTerminal=1)
        else:
            child = BuildSigTree(subset, nPossibleRes, random=random, ensemble=ensemble,
                                 metric=metric, biasList=biasList, depth=depth + 1,
                                 maxDepth=maxDepth, verbose=verbose)
            if child is None:
                # defensive fallback, kept from the original code
                v = numpy.argmax(counts)
                tree.AddChild('%d???' % v, label=v, data=0.0, isTerminal=1)
            else:
                tree.AddChildNode(child)
    return tree
227
228
def SigTreeBuilder(examples, attrs, nPossibleVals, initialVar=None, ensemble=None,
                   randomDescriptors=0, **kwargs):
    """ Builds a signature tree by delegating to BuildSigTree

    **Arguments**

      - examples: the examples to be classified (passed to BuildSigTree)

      - attrs: not used by this function

      - nPossibleVals: a sequence whose last entry is used as the
        number of possible result codes

      - initialVar: (optional) not used by this function

      - ensemble: (optional) not used by this function; it is NOT
        forwarded to BuildSigTree
        (NOTE(review): presumably kept so the signature matches other
        tree builders -- confirm against the callers)

      - randomDescriptors: (optional) passed to BuildSigTree as its
        random argument.
        The default is 0.

      - kwargs: passed on to BuildSigTree unchanged

    **Returns**

      whatever BuildSigTree returns (the root node of the tree)

    """
    nRes = nPossibleVals[-1]
    return BuildSigTree(examples, nRes, random=randomDescriptors, **kwargs)
233