1
2
3
4
5
6
7
8
9
10
11 """Command line tool to construct an enrichment plot from saved composite models
12
13 Usage: EnrichPlot [optional args] -d dbname -t tablename <models>
14
15 Required Arguments:
16 -d "dbName": the name of the database for screening
17
18 -t "tablename": provide the name of the table with the data to be screened
19
20 <models>: file name(s) of pickled composite model(s).
21 If the -p argument is also provided (see below), this argument is ignored.
22
23 Optional Arguments:
24 - -a "list": the list of result codes to be considered active. This will be
25 eval'ed, so be sure that it evaluates as a list or sequence of
26 integers. For example, -a "[1,2]" will consider activity values 1 and 2
27 to be active
28
29 - --enrich "list": identical to the -a argument above.
30
31 - --thresh: sets a threshold for the plot. If the confidence falls below
32 this value, picking will be terminated
33
34 - -H: screen only the hold out set (works only if a version of
35 BuildComposite more recent than 1.2.2 was used).
36
37 - -T: screen only the training set (works only if a version of
38 BuildComposite more recent than 1.2.2 was used).
39
40 - -S: shuffle activity values before screening
41
42 - -R: randomize activity values before screening
43
44 - -F *filter frac*: filters the data before training to change the
45 distribution of activity values in the training set. *filter frac*
46 is the fraction of the training set that should have the target value.
47 **See note in BuildComposite help about data filtering**
48
49 - -v *filter value*: filters the data before training to change the
50 distribution of activity values in the training set. *filter value*
51 is the target value to use in filtering.
52 **See note in BuildComposite help about data filtering**
53
54 - -p "tableName": provides the name of a db table containing the
55 models to be screened. If you use this argument, you should also
56 use the -N argument (below) to specify a note value.
57
58 - -N "note": provides a note to be used to pull models from a db table.
59
60 - --plotFile "filename": writes the data to an output text file (filename.dat)
61 and creates a gnuplot input file (filename.gnu) to plot it
62
63 - --showPlot: causes the gnuplot plot constructed using --plotFile to be
64 displayed in gnuplot.
65
66 """
67
68
69 from __future__ import print_function
70
71 import sys
72
73 import numpy
74
75 from rdkit import DataStructs
76 from rdkit import RDConfig
77 from rdkit.Dbase.DbConnection import DbConnect
78 from rdkit.ML import CompositeRun
79 from rdkit.ML.Data import DataUtils, SplitData, Stats
80 from rdkit.six import PY3
81 from rdkit.six.moves import cPickle
82 from rdkit.six.moves import input
83
84
# Version of this tool; written into the header of generated gnuplot files.
__VERSION_STRING = "2.4.0"

if PY3:

    def cmp(t1, t2):
        """ Python 3 stand-in for the removed builtin cmp()

        **Arguments**

          - t1, t2: two mutually comparable values

        **Returns**

          -1 if t1 < t2, 1 if t1 > t2, 0 otherwise

        """
        return (t1 < t2) * -1 or (t1 > t2) * 1
91
92
def message(msg, noRet=0, dest=sys.stderr):
    """ writes a progress message to a stream (_sys.stderr_ by default)

    Modules which import this one may override this function to
    redirect output.

    **Arguments**

      - msg: the string to be displayed

      - noRet: (Optional) if true, the message is followed by a single
        space instead of a newline

      - dest: (Optional) the file-like object the message is written to

    """
    fmt = '%s ' if noRet else '%s\n'
    dest.write(fmt % msg)
106
107
def error(msg, dest=sys.stderr):
    """ emits an error message, prefixed with "ERROR: ", to a stream
    (_sys.stderr_ by default)

    Modules which import this one may override this function to
    redirect output.

    **Arguments**

      - msg: the string to be displayed

      - dest: (Optional) the file-like object the message is written to

    """
    # honour the dest argument; it used to be accepted but silently ignored
    dest.write('ERROR: %s\n' % (msg))
118
119
def ScreenModel(mdl, descs, data, picking=None, indices=None, errorEstimate=0):
    """ collects the results of screening an individual composite model that match
      a particular value

    **Arguments**

      - mdl: the composite model

      - descs: a list of descriptor names corresponding to the data set

      - data: the data set, a list of points to be screened.

      - picking: (Optional) a list of values that are to be collected.
        For example, if you want an enrichment plot for picking the values
        1 and 2, you'd have picking=[1,2].  Defaults to [1].

      - indices: (Optional) the indices of the points in _data_ to screen.
        If empty or not provided, every point is screened.

      - errorEstimate: (Optional) if true, each point is classified using
        only those models of the composite whose _trainIndices do not
        contain the point's index.

    **Returns**

      a 2-tuple containing:

        - the number of screened points whose true result is in _picking_

        - a list of 4-tuples, one for each point whose *predicted* result
          is in _picking_, containing:

            - the id of the point

            - the true result (from the data set)

            - the predicted result

            - the confidence value for the prediction

    **Notes**

      - as a side effect, any model in the composite whose _trainIndices
        attribute is not a dictionary has it replaced with one.

    """
    if picking is None:
        picking = [1]

    mdl.SetInputOrder(descs)

    # make sure train indices are stored as dictionaries so that the
    # per-point membership tests below are cheap
    for j in range(len(mdl)):
        subModel = mdl.GetModel(j)
        if hasattr(subModel, '_trainIndices') and not isinstance(subModel._trainIndices, dict):
            subModel._trainIndices = dict.fromkeys(subModel._trainIndices, 1)

    needsQuant = bool(mdl.GetQuantBounds())

    if not indices:
        indices = list(range(len(data)))

    res = []
    nTrueActives = 0
    for i in indices:
        if errorEstimate:
            # only use models which did not see this point during training
            use = [j for j in range(len(mdl)) if not mdl.GetModel(j)._trainIndices.get(i, 0)]
        else:
            use = None
        pt = data[i]
        pred, conf = mdl.ClassifyExample(pt, onlyModels=use)
        if needsQuant:
            # quantize a copy so that the data set itself is left untouched
            pt = mdl.QuantizeActivity(pt[:])
        trueRes = pt[-1]
        if trueRes in picking:
            nTrueActives += 1
        if pred in picking:
            res.append((pt[0], trueRes, pred, conf))
    return nTrueActives, res
188
189
def AccumulateCounts(predictions, thresh=0, sortIt=1):
    """  Accumulates the data for the enrichment plot for a single model

    **Arguments**

      - predictions: a list of 4-tuples (as returned by _ScreenModel_):
        (id, true result, predicted result, confidence)

      - thresh: a threshold for the confidence level.  Anything at or below
        this threshold will not be considered

      - sortIt: toggles sorting on confidence levels (highest confidence
        first)

    **Returns**

      - a list of 3-tuples:

        - the id of the active picked here

        - num actives found so far

        - number of picks made so far

    **Notes**

      - when _sortIt_ is true, _predictions_ is sorted in place.

    """
    if sortIt:
        # list.sort() no longer accepts a comparison function in Python 3;
        # a stable descending sort on the confidence gives the same ordering
        predictions.sort(key=lambda x: x[3], reverse=True)
    res = []
    nCorrect = 0
    nPts = 0
    for ID, real, pred, conf in predictions:
        if conf > thresh:
            if pred == real:
                nCorrect += 1
            nPts += 1
            res.append((ID, nCorrect, nPts))

    return res
228
229
def MakePlot(details, final, counts, pickVects, nModels, nTrueActs=-1):
    """ writes the enrichment data and a gnuplot script to plot it

    Two files are created: <plotFile>.dat (the data) and <plotFile>.gnu
    (the gnuplot commands).  Nothing is done if _details_ has no plotFile
    attribute or if it is empty.

    **Arguments**

      - details: an object providing _plotFile_ (base name of the output
        files) and, optionally, _showPlot_ (if true the plot is displayed
        using the Gnuplot python module)

      - final: a sequence of (summed num correct, summed num picks) pairs,
        one per pick index

      - counts: for each pick index, the number of models contributing to
        the sums; output stops at the first zero

      - pickVects: maps a pick index to the per-model num-correct values;
        only used when nModels > 1 (for the confidence interval)

      - nModels: the number of models that were screened

      - nTrueActs: (Optional) if positive, used as the upper limit of the
        y range of the plot

    """
    if not hasattr(details, 'plotFile') or not details.plotFile:
        return

    dataFileName = '%s.dat' % (details.plotFile)
    i = 0
    with open(dataFileName, 'w+') as outF:
        while i < len(final) and counts[i] != 0:
            if nModels > 1:
                _, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                outF.write('%d %f %f %d %f\n' % (i + 1, final[i][0] / counts[i],
                                                 final[i][1] / counts[i], counts[i],
                                                 confInterval))
            else:
                outF.write('%d %f %f %d\n' % (i + 1, final[i][0] / counts[i],
                                              final[i][1] / counts[i], counts[i]))
            i += 1

    plotFileName = '%s.gnu' % (details.plotFile)
    gnuHdr = '\n'.join((
        '# Generated by EnrichPlot.py version: %s',
        'set size square 0.7',
        'set xr [0:]',
        'set data styl points',
        "set ylab 'Num Correct Picks'",
        "set xlab 'Num Picks'",
        'set grid',
        'set nokey',
        'set term postscript enh color solid "Helvetica" 16',
        'set term X',
        '',
    )) % (__VERSION_STRING)
    with open(plotFileName, 'w+') as gnuF:
        print(gnuHdr, file=gnuF)
        if nTrueActs > 0:
            print('set yr [0:%d]' % nTrueActs, file=gnuF)
        print('plot x with lines', file=gnuF)
        if nModels > 1:
            # draw roughly 20 error bars; the increment has to be a positive
            # integer, so clamp it for data sets with fewer than 20 points
            everyGap = max(i // 20, 1)
            print('replot "%s" using 1:2 with lines,' % (dataFileName), end='', file=gnuF)
            print('"%s" every %d using 1:2:5 with yerrorbars' % (dataFileName, everyGap),
                  file=gnuF)
        else:
            print('replot "%s" with points' % (dataFileName), file=gnuF)

    if hasattr(details, 'showPlot') and details.showPlot:
        # displaying the plot is best effort: report problems, don't die
        try:
            from Gnuplot import Gnuplot
            p = Gnuplot()
            p('load "%s"' % (plotFileName))
            input('press return to continue...\n')
        except Exception:
            import traceback
            traceback.print_exc()
282
283
def Usage():
    """ displays a usage message (the module docstring) on stderr and
    exits with status -1 """
    # __doc__ is None when docstrings have been stripped (python -OO)
    sys.stderr.write(__doc__ or '')
    sys.exit(-1)
288
289
if __name__ == '__main__':
    # Script entry point: parse the command line, pull the data set and the
    # composite models, screen every model and either write a gnuplot
    # plot (--plotFile) or print the enrichment table to stdout.
    import getopt
    import traceback

    try:
        args, extras = getopt.getopt(sys.argv[1:], 'd:t:a:N:p:cSTHF:v:',
                                     ('thresh=', 'plotFile=', 'showPlot', 'pickleCol=', 'OOB',
                                      'noSort', 'pickBase=', 'doROC', 'rocThresh=', 'enrich='))
    except Exception:
        traceback.print_exc()
        Usage()

    details = CompositeRun.CompositeRun()
    CompositeRun.SetDefaults(details)

    # defaults for the options specific to this tool
    details.activeTgt = [1]
    details.doTraining = 0
    details.doHoldout = 0
    details.dbTableName = ''
    details.plotFile = ''
    details.showPlot = 0
    details.pickleCol = -1
    details.errorEstimate = 0
    details.sortIt = 1
    details.pickBase = ''
    details.doROC = 0
    details.rocThresh = -1
    for arg, val in args:
        if arg == '-d':
            details.dbName = val
        elif arg == '-t':
            details.dbTableName = val
        elif arg == '-a' or arg == '--enrich':
            # NOTE(review): eval() of a command-line argument executes
            # arbitrary code; ast.literal_eval would be the safer choice if
            # only literal lists/ints need to be supported.
            details.activeTgt = eval(val)
            if not isinstance(details.activeTgt, (tuple, list)):
                details.activeTgt = (details.activeTgt, )
        elif arg == '--thresh':
            details.threshold = float(val)
        elif arg == '-N':
            details.note = val
        elif arg == '-p':
            details.persistTblName = val
        elif arg == '-S':
            details.shuffleActivities = 1
        elif arg == '-H':
            details.doTraining = 0
            details.doHoldout = 1
        elif arg == '-T':
            details.doTraining = 1
            details.doHoldout = 0
        elif arg == '-F':
            details.filterFrac = float(val)
        elif arg == '-v':
            details.filterVal = float(val)
        elif arg == '--plotFile':
            details.plotFile = val
        elif arg == '--showPlot':
            details.showPlot = 1
        elif arg == '--pickleCol':
            # the command line value is 1-based
            details.pickleCol = int(val) - 1
        elif arg == '--OOB':
            details.errorEstimate = 1
        elif arg == '--noSort':
            details.sortIt = 0
        elif arg == '--doROC':
            details.doROC = 1
        elif arg == '--rocThresh':
            details.rocThresh = int(val)
        elif arg == '--pickBase':
            details.pickBase = val

    if not details.dbName or not details.dbTableName:
        # Usage() exits, so the hint has to be printed before calling it
        print('*******Please provide both the -d and -t arguments', file=sys.stderr)
        Usage()

    message('Building Data set\n')
    dataSet = DataUtils.DBToData(details.dbName, details.dbTableName,
                                 user=RDConfig.defaultDBUser,
                                 password=RDConfig.defaultDBPassword,
                                 pickleCol=details.pickleCol,
                                 pickleClass=DataStructs.ExplicitBitVect)

    descs = dataSet.GetVarNames()
    nPts = dataSet.GetNPts()
    message('npts: %d\n' % (nPts))
    # per pick index: [summed num correct, summed num picks] over all models
    # (the numpy.float alias was removed in NumPy 1.24; use the builtins)
    final = numpy.zeros((nPts, 2), float)
    # per pick index: number of models which made that many picks
    counts = numpy.zeros(nPts, int)
    selPts = [None] * nPts

    models = []
    if details.persistTblName:
        conn = DbConnect(details.dbName, details.persistTblName)
        message('-> Retrieving models from database')
        curs = conn.GetCursor()
        # NOTE(review): the query is assembled from command-line values by
        # string formatting; switch to a parameterized query if the table
        # name or note can come from an untrusted source.
        curs.execute("select model from %s where note='%s'" % (details.persistTblName,
                                                               details.note))
        message('-> Reconstructing models')
        try:
            blob = curs.fetchone()
        except Exception:
            blob = None
        while blob:
            message(' Building model %d' % len(models))
            blob = blob[0]
            try:
                # NOTE(review): unpickling executes code from the database;
                # only use trusted databases.  str(blob) presumably relies
                # on Python 2 buffer semantics - confirm under Python 3.
                models.append(cPickle.loads(str(blob)))
            except Exception:
                traceback.print_exc()
                print('Model failed')
            else:
                message(' <-Done')
            try:
                blob = curs.fetchone()
            except Exception:
                blob = None
        curs = None
    else:
        for modelName in extras:
            try:
                # NOTE(review): unpickling executes code from the file; only
                # load model files from trusted sources.
                with open(modelName, 'rb') as modelFile:
                    model = cPickle.load(modelFile)
            except Exception:
                print('problems with model %s:' % modelName)
                traceback.print_exc()
            else:
                models.append(model)

    nModels = len(models)
    pickVects = {}
    halfwayPts = [1e8] * len(models)
    # keeps the plotting call below well defined when no model could be loaded
    nTrueActives = -1
    for whichModel, model in enumerate(models):
        tmpD = dataSet
        try:
            seed = model._randomSeed
        except AttributeError:
            pass
        else:
            DataUtils.InitRandomNumbers(seed)
        if details.shuffleActivities:
            DataUtils.RandomizeActivities(tmpD, shuffle=1)
        if hasattr(model, '_splitFrac') and (details.doHoldout or details.doTraining):
            trainIdx, testIdx = SplitData.SplitIndices(tmpD.GetNPts(), model._splitFrac,
                                                       silent=1)
            if details.filterFrac != 0.0:
                trainFilt, temp = DataUtils.FilterData(tmpD, details.filterVal,
                                                       details.filterFrac, -1,
                                                       indicesToUse=trainIdx, indicesOnly=1)
                testIdx += temp
                trainIdx = trainFilt
            if details.doTraining:
                # screen the training set instead of the hold out set
                testIdx, trainIdx = trainIdx, testIdx
        else:
            testIdx = list(range(tmpD.GetNPts()))

        message('screening %d examples' % (len(testIdx)))
        nTrueActives, screenRes = ScreenModel(model, descs, tmpD, picking=details.activeTgt,
                                              indices=testIdx,
                                              errorEstimate=details.errorEstimate)
        message('accumulating')
        runningCounts = AccumulateCounts(screenRes, sortIt=details.sortIt,
                                         thresh=details.threshold)
        if details.pickBase:
            pickFile = open('%s.%d.picks' % (details.pickBase, whichModel + 1), 'w+')
        else:
            pickFile = None

        try:
            for i, entry in enumerate(runningCounts):
                selPts[i] = entry[0]
                final[i][0] += entry[1]
                final[i][1] += entry[2]
                pickVects.setdefault(i, []).append(entry[1])
                counts[i] += 1
                if pickFile:
                    pickFile.write('%s\n' % (entry[0]))
                # remember the first pick at which half of the actives were found
                if entry[1] >= nTrueActives / 2 and entry[2] < halfwayPts[whichModel]:
                    halfwayPts[whichModel] = entry[2]
                    message('Halfway point: %d\n' % halfwayPts[whichModel])
        finally:
            if pickFile is not None:
                pickFile.close()

    if details.plotFile:
        MakePlot(details, final, counts, pickVects, nModels, nTrueActs=nTrueActives)
    else:
        if nModels > 1:
            print('#Index\tAvg_num_correct\tConf90Pct\tAvg_num_picked\tNum_picks\tlast_selection')
        else:
            print('#Index\tAvg_num_correct\tAvg_num_picked\tNum_picks\tlast_selection')

        i = 0
        while i < nPts and counts[i] != 0:
            if nModels > 1:
                mean, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                print('%d\t%f\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i], confInterval,
                                                  final[i][1] / counts[i], counts[i],
                                                  str(selPts[i])))
            else:
                print('%d\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i],
                                              final[i][1] / counts[i], counts[i],
                                              str(selPts[i])))
            i += 1

    mean, sd = Stats.MeanAndDev(halfwayPts)
    print('Halfway point: %.2f(%.2f)' % (mean, sd))
487