1
2
3
4
5 from __future__ import print_function
6
7 import random
8
9 from rdkit import RDRandom
10
# Types treated as sequences by this module's type checks (SplitDbData uses
# it to decide whether `fracs` is a single value or a per-activity sequence).
SeqTypes = (list, tuple)
12
13
def SplitIndices(nPts, frac, silent=1, legacy=0, replacement=0):
    """ splits a set of indices into a data set into 2 pieces

    **Arguments**

      - nPts: the total number of points

      - frac: the fraction of the data to be put in the first data set;
        must lie in [0.0, 1.0]

      - silent: (optional) toggles display of stats

      - legacy: (optional) use the legacy splitting approach

      - replacement: (optional) use selection with replacement

    **Returns**

      a 2-tuple of lists: (indices for the first/training set,
      indices for the second/hold-out set).

    **Raises**

      ValueError if _frac_ is outside [0.0, 1.0] (NaN included).

    **Notes**

      - the _legacy_ splitting approach uses randomly-generated floats
        and compares them to _frac_.  This is provided for
        backwards-compatibility reasons.

      - the default splitting approach uses a random permutation of
        indices which is split into two parts.

      - selection with replacement can generate duplicates.


    **Usage**:

    We'll start with a set of indices and pick from them using
    the three different approaches:
    >>> from rdkit.ML.Data import DataUtils

    The base approach always returns the same number of compounds in
    each set and has no duplicates:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5)
    >>> train
    [1, 5, 6, 4, 2]
    >>> test
    [3, 0, 7, 8, 9]

    >>> train,test = SplitIndices(10,.5)
    >>> train
    [5, 2, 9, 8, 7]
    >>> test
    [6, 0, 3, 1, 4]


    The legacy approach can return varying numbers, but still has no
    duplicates. Note the indices come back ordered:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5,legacy=1)
    >>> train
    [3, 5, 7, 8, 9]
    >>> test
    [0, 1, 2, 4, 6]

    >>> train,test = SplitIndices(10,.5,legacy=1)
    >>> train
    [0, 1, 2, 3, 5, 8, 9]
    >>> test
    [4, 6, 7]

    The replacement approach returns a fixed number in the training set,
    a variable number in the test set and can contain duplicates in the
    training set.
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitIndices(10,.5,replacement=1)
    >>> train
    [9, 9, 8, 0, 5]
    >>> test
    [1, 2, 3, 4, 6, 7]
    >>> train,test = SplitIndices(10,.5,replacement=1)
    >>> train
    [4, 5, 1, 1, 4]
    >>> test
    [0, 2, 3, 6, 7, 8, 9]

    """
    # The chained comparison also rejects NaN, which separate < / > tests let through.
    if not 0.0 <= frac <= 1.0:
        raise ValueError('frac must be between 0.0 and 1.0 (frac=%f)' % (frac))

    if replacement:
        # Draw int(nPts * frac) indices with replacement; every index that is
        # never drawn goes to the second set.
        nTrain = int(nPts * frac)
        resData = []
        for _ in range(nTrain):
            val = int(RDRandom.random() * nPts)
            if val == nPts:
                # defensive: keep the index in range should RDRandom.random()
                # ever return 1.0
                val = nPts - 1
            resData.append(val)
        drawn = set(resData)  # set membership keeps this pass linear in nPts
        resTest = [i for i in range(nPts) if i not in drawn]
    elif legacy:
        # One RDRandom.random() draw per index, compared against frac, so the
        # set sizes vary between calls; both lists stay in index order.
        resData = []
        resTest = []
        for i in range(nPts):
            if RDRandom.random() < frac:
                resData.append(i)
            else:
                resTest.append(i)
    else:
        # Fisher-Yates shuffle driven by random.random().  This is exactly what
        # random.shuffle(perm, random=random.random) used to do; that keyword
        # was removed in Python 3.11, and reproducing the algorithm keeps
        # seeded results (and the doctests above) unchanged.
        perm = list(range(nPts))
        for i in reversed(range(1, nPts)):
            j = int(random.random() * (i + 1))
            perm[i], perm[j] = perm[j], perm[i]
        nTrain = int(nPts * frac)
        resData = perm[:nTrain]
        resTest = perm[nTrain:]

    if not silent:
        print('Training with %d (of %d) points.' % (len(resData), nPts))
        print('\t%d points are in the hold-out set.' % (len(resTest)))
    return resData, resTest
134
135
137 """ splits a data set into two pieces
138
139 **Arguments**
140
141 - data: a list of examples to be split
142
143 - frac: the fraction of the data to be put in the first data set
144
145 - silent: controls the amount of visual noise produced.
146
147 **Returns**
148
149 a 2-tuple containing the two new data sets.
150
151 """
152 if frac < 0. or frac > 1.:
153 raise ValueError('frac must be between 0.0 and 1.0')
154
155 nOrig = len(data)
156 train, test = SplitIndices(nOrig, frac, silent=1)
157 resData = [data[x] for x in train]
158 resTest = [data[x] for x in test]
159
160 if not silent:
161 print('Training with %d (of %d) points.' % (len(resData), nOrig))
162 print('\t%d points are in the hold-out set.' % (len(resTest)))
163 return resData, resTest
164
165
def _whereCondition(where):
    """ returns the bare SQL condition held in a user-supplied where clause

    A leading ``where`` keyword (in any case) and surrounding whitespace are
    removed, so ``'where x=1'``, ``'WHERE x=1'`` and ``'x=1'`` all give
    ``'x=1'``; an empty or whitespace-only clause gives ``''``.
    """
    import re
    # \b stops a column such as "whereabouts" from being mistaken for the keyword
    return re.sub(r'^\s*where\b', '', where, flags=re.IGNORECASE).strip()


def _activityCondition(actCol, act, nActs, actBounds):
    """ returns the SQL condition selecting the rows of activity class _act_

    Without _actBounds_ the activity column is compared for equality with the
    class number.  With bounds, class _act_ covers
    actBounds[act-1] <= value < actBounds[act]; the first class has no lower
    bound and the last class has no upper bound.
    """
    if not actBounds:
        return '%s=%d' % (actCol, act)
    terms = []
    if act != 0:
        terms.append('%s>=%f' % (actCol, actBounds[act - 1]))
    if act < nActs - 1:
        terms.append('%s<%f' % (actCol, actBounds[act]))
    return ' and '.join(terms)


def SplitDbData(conn, fracs, table='', fields='*', where='', join='', labelCol='', useActs=0,
                nActs=2, actCol='', actBounds=(), silent=0):
    """ "splits" a data set held in a DB by returning lists of ids

    **Arguments**:

      - conn: a DbConnect object

      - fracs: the split fraction, i.e. the fraction of points put in the
        first (training) list.  When _useActs_ is set this can optionally be
        specified as a sequence with a different fraction for each activity
        value; otherwise only the first element of a sequence is used.

      - table,fields,where,join: (optional) SQL query parameters.
        _where_ may be given with or without a leading ``where`` keyword
        and restricts the rows considered in both modes.

      - labelCol: (optional) currently unused; accepted for backwards
        compatibility

      - useActs: (optional) toggles splitting based on activities
        (ensuring that the given fraction of each activity class ends
        up in the first (training) list)
        Defaults to 0

      - nActs: (optional) number of possible activity values, only
        used if _useActs_ is nonzero
        Defaults to 2

      - actCol: (optional) name of the activity column
        Defaults to use the last column returned by the query

      - actBounds: (optional) sequence of nActs-1 activity bounds
        (for cases where the activity isn't quantized in the db)
        Defaults to an empty sequence

      - silent: (optional) currently unused; this function prints nothing

    **Returns**

      a 2-tuple of lists: (training ids, hold-out ids).  The ids are the
      values of the first column name reported by conn.GetColumnNames for
      _table_/_fields_/_join_.  With _useActs_ the ids are grouped by
      activity class.

    **Raises**

      ValueError if a fraction is outside [0.0, 1.0] or if _actBounds_ is
      non-empty but does not have nActs-1 entries.

    **Usage**:

    Set up the db connection, the simple tables we're using have actives with even
    ids and inactives with odd ids:
    >>> from rdkit.ML.Data import DataUtils
    >>> from rdkit.Dbase.DbConnection import DbConnect
    >>> from rdkit import RDConfig
    >>> conn = DbConnect(RDConfig.RDTestDatabase)

    Pull a set of points from a simple table... take 33% of all points:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,1./3.,'basic_2class')
    >>> [str(x) for x in train]
    ['id-7', 'id-6', 'id-2', 'id-8']

    ...take 50% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']


    Notice how the results come out grouped by activity class.

    We can be asymmetrical: take 33% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,[.5,1./3.],'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10']

    And we can pull from tables with non-quantized activities by providing
    activity quantization bounds:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'float_2class',useActs=1,actBounds=[1.0])
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']

    """
    if not table:
        table = conn.tableName
    if actBounds and len(actBounds) != nActs - 1:
        raise ValueError('activity bounds list length incorrect')

    # Normalise fracs and validate every fraction before touching the DB.
    if useActs:
        if not isinstance(fracs, SeqTypes):
            fracs = tuple([fracs] * nActs)
        toCheck = fracs
    else:
        frac = fracs[0] if isinstance(fracs, SeqTypes) else fracs
        toCheck = (frac,)
    for value in toCheck:
        # chained comparison also rejects NaN
        if not 0.0 <= value <= 1.0:
            raise ValueError('fractions must be between 0.0 and 1.0')

    colNames = conn.GetColumnNames(table=table, what=fields, join=join)
    idCol = colNames[0]
    condition = _whereCondition(where)

    if not useActs:
        if condition:
            d = conn.GetData(table=table, fields=idCol, join=join, where='where ' + condition)
        else:
            d = conn.GetData(table=table, fields=idCol, join=join)
        ids = [x[0] for x in d]
        train, test = SplitIndices(len(ids), frac, silent=1)
        return [ids[x] for x in train], [ids[x] for x in test]

    if not actCol:
        actCol = colNames[-1]
    # Each per-class query combines the caller's condition (if any) with the
    # condition selecting that activity class.
    prefix = 'where %s and ' % condition if condition else 'where '
    trainPts = []
    testPts = []
    for act in range(nActs):
        whereTxt = prefix + _activityCondition(actCol, act, nActs, actBounds)
        d = conn.GetData(table=table, fields=idCol, join=join, where=whereTxt)
        ids = [x[0] for x in d]
        train, test = SplitIndices(len(ids), fracs[act], silent=1)
        trainPts.extend(ids[x] for x in train)
        testPts.extend(ids[x] for x in test)
    return trainPts, testPts
294
295
296
297
298
299
301 import sys
302 import doctest
303 failed, _ = doctest.testmod(optionflags=doctest.ELLIPSIS, verbose=verbose)
304 sys.exit(failed)
305
306
# Script entry point: run this module's doctests via _runDoctests().
if __name__ == '__main__':
    _runDoctests()
309