Package rdkit :: Package ML :: Module EnrichPlot
[hide private]
[frames] | [no frames]

Source Code for Module rdkit.ML.EnrichPlot

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2002-2006  greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved @@ 
  6  #  This file is part of the RDKit. 
  7  #  The contents are covered by the terms of the BSD license 
  8  #  which is included in the file license.txt, found at the root 
  9  #  of the RDKit source tree. 
 10  # 
 11  """Command line tool to construct an enrichment plot from saved composite models 
 12   
 13  Usage:  EnrichPlot [optional args] -d dbname -t tablename <models> 
 14   
 15  Required Arguments: 
 16    -d "dbName": the name of the database for screening 
 17   
 18    -t "tablename": provide the name of the table with the data to be screened 
 19   
 20    <models>: file name(s) of pickled composite model(s). 
 21       If the -p argument is also provided (see below), this argument is ignored. 
 22   
 23  Optional Arguments: 
 24    - -a "list": the list of result codes to be considered active.  This will be 
 25          eval'ed, so be sure that it evaluates as a list or sequence of 
 26          integers. For example, -a "[1,2]" will consider activity values 1 and 2 
 27          to be active 
 28   
 29    - --enrich "list": identical to the -a argument above. 
 30   
 31    - --thresh: sets a threshold for the plot.  If the confidence falls below 
 32            this value, picking will be terminated 
 33   
 34    - -H: screen only the hold out set (works only if a version of 
 35          BuildComposite more recent than 1.2.2 was used). 
 36   
 37    - -T: screen only the training set (works only if a version of 
 38          BuildComposite more recent than 1.2.2 was used). 
 39   
 40    - -S: shuffle activity values before screening 
 41   
 42    - -R: randomize activity values before screening 
 43   
 44    - -F *filter frac*: filters the data before training to change the 
 45       distribution of activity values in the training set.  *filter frac* 
 46       is the fraction of the training set that should have the target value. 
 47       **See note in BuildComposite help about data filtering** 
 48   
 49    - -v *filter value*: filters the data before training to change the 
 50       distribution of activity values in the training set. *filter value* 
 51       is the target value to use in filtering. 
 52       **See note in BuildComposite help about data filtering** 
 53   
 54    - -p "tableName": provides the name of a db table containing the 
 55        models to be screened.  If you use this argument, you should also 
 56        use the -N argument (below) to specify a note value. 
 57   
 58    - -N "note": provides a note to be used to pull models from a db table. 
 59   
 60    - --plotFile "filename": writes the data to an output text file (filename.dat) 
 61      and creates a gnuplot input file (filename.gnu) to plot it 
 62   
 63    - --showPlot: causes the gnuplot plot constructed using --plotFile to be 
 64      displayed in gnuplot. 
 65   
 66  """ 
 67  # from rdkit.Dbase.DbConnection import DbConnect 
 68   
 69  from __future__ import print_function 
 70   
 71  import sys 
 72   
 73  import numpy 
 74   
 75  from rdkit import DataStructs 
 76  from rdkit import RDConfig 
 77  from rdkit.Dbase.DbConnection import DbConnect 
 78  from rdkit.ML import CompositeRun 
 79  from rdkit.ML.Data import DataUtils, SplitData, Stats 
 80  from rdkit.six import PY3 
 81  from rdkit.six.moves import cPickle 
 82  from rdkit.six.moves import input 
 83   
 84   
 85  __VERSION_STRING = "2.4.0" 
 86   
 87  if PY3: 
 88   
89 - def cmp(t1, t2):
90 return (t1 < t2) * -1 or (t1 > t2) * 1
91 92
def message(msg, noRet=0, dest=sys.stderr):
    """ writes _msg_ to _dest_ (by default _sys.stderr_)

    Modules that import this one may override this function to redirect
    output elsewhere.

    **Arguments**

      - msg: the object to be displayed (formatted with the %s operator)

      - noRet: if true, the message is terminated with a single space
        instead of a newline

      - dest: the file-like object that is written to

    """
    template = '%s ' if noRet else '%s\n'
    dest.write(template % (msg))
106 107
def error(msg, dest=None):
    """ emits an error message, prefixed with 'ERROR: ', to _dest_

    Modules that import this one may override this function to redirect
    output elsewhere.

    **Arguments**

      - msg: the object to be displayed (formatted with the %s operator)

      - dest: (optional) the file-like object to write to.  When omitted or
        None, the *current* _sys.stderr_ is used (looked up at call time,
        so redirections of sys.stderr made after import are honored).

    """
    # the dest argument used to be silently ignored; honor it now
    if dest is None:
        dest = sys.stderr
    dest.write('ERROR: %s\n' % (msg))
118 119
def ScreenModel(mdl, descs, data, picking=None, indices=None, errorEstimate=0):
    """ collects the results of screening an individual composite model that
    match a particular value

    **Arguments**

      - mdl: the composite model

      - descs: a list of descriptor names corresponding to the data set

      - data: the data set, a list of points to be screened.  The first
        element of each point is used as its id and the last as its true
        result.

      - picking: (Optional) a sequence of values that are to be collected.
        For example, if you want an enrichment plot for picking the values
        1 and 2, you'd use picking=[1,2].  Defaults to [1].

      - indices: (Optional) the indices of the points in _data_ to screen.
        If empty or not provided, every point is screened.

      - errorEstimate: (Optional) if nonzero, each point is classified using
        only the sub-models whose _trainIndices do not contain that point's
        index (an out-of-bag estimate).

    **Returns**

      a 2-tuple:

        - the number of screened points whose true result is in _picking_

        - a list of 4-tuples, one per screened point whose *predicted*
          result is in _picking_, containing:

            - the id of the point

            - the true result (from the data set)

            - the predicted result

            - the confidence value for the prediction

    """
    if picking is None:
        picking = [1]
    mdl.SetInputOrder(descs)

    # normalize each sub-model's _trainIndices to a dict so the per-point
    # membership test below is a dict lookup
    for j in range(len(mdl)):
        tmp = mdl.GetModel(j)
        if hasattr(tmp, '_trainIndices') and not isinstance(tmp._trainIndices, dict):
            tmp._trainIndices = dict.fromkeys(tmp._trainIndices, 1)

    needsQuant = bool(mdl.GetQuantBounds())

    if not indices:
        indices = list(range(len(data)))

    if errorEstimate:
        # fetched once instead of once per screened point
        trainSets = [mdl.GetModel(j)._trainIndices for j in range(len(mdl))]

    res = []
    nTrueActives = 0
    for i in indices:
        if errorEstimate:
            # only the sub-models that did not see point i during training
            use = [j for j, tis in enumerate(trainSets) if not tis.get(i, 0)]
        else:
            use = None
        pt = data[i]
        pred, conf = mdl.ClassifyExample(pt, onlyModels=use)
        if needsQuant:
            pt = mdl.QuantizeActivity(pt[:])
        trueRes = pt[-1]
        if trueRes in picking:
            nTrueActives += 1
        if pred in picking:
            res.append((pt[0], trueRes, pred, conf))
    return nTrueActives, res
188 189
def AccumulateCounts(predictions, thresh=0, sortIt=1):
    """ Accumulates the data for the enrichment plot for a single model

    **Arguments**

      - predictions: a list of 4-tuples (id, true result, predicted result,
        confidence), as in the second element of _ScreenModel_'s return
        value.  If _sortIt_ is set, this list is sorted in place.

      - thresh: a threshold for the confidence level.  Only predictions
        whose confidence is strictly greater than this are counted.

      - sortIt: toggles sorting on confidence levels (highest first)

    **Returns**

      - a list of 3-tuples, one per counted prediction:

        - the id of the point picked here

        - number of correct picks (predicted == true) found so far

        - number of picks made so far

    """
    if sortIt:
        # Highest confidence first.  list.sort is stable (also with
        # reverse=True), so ties keep their input order.  A positional
        # comparison function is rejected by Python 3's list.sort.
        predictions.sort(key=lambda p: p[3], reverse=True)
    res = []
    nCorrect = 0
    nPts = 0
    for ID, real, pred, conf in predictions:
        if conf > thresh:
            if pred == real:
                nCorrect += 1
            nPts += 1
            res.append((ID, nCorrect, nPts))

    return res
228 229
def MakePlot(details, final, counts, pickVects, nModels, nTrueActs=-1):
    """ writes the accumulated enrichment data to a text file and builds a
    gnuplot script that plots it

    Nothing is done unless _details_ has a non-empty _plotFile_ attribute.

    **Arguments**

      - details: run settings.  _plotFile_ is the base name of the output
        files; if _showPlot_ is present and true, the plot is also displayed
        through the optional Gnuplot module.

      - final: per-row sums over models; final[i][0] is the summed number of
        correct picks and final[i][1] the summed number of picks

      - counts: counts[i] is the number of models that contributed to row i;
        rows are written until the first zero count

      - pickVects: maps a row index to the list of per-model values for that
        row; only used (for 90% confidence-interval error bars) when
        nModels > 1

      - nModels: the number of models screened

      - nTrueActs: (optional) if positive, used as the upper bound of the
        plot's y range

    Writes _plotFile_.dat (the data) and _plotFile_.gnu (the gnuplot script).
    """
    if not getattr(details, 'plotFile', None):
        return

    dataFileName = '%s.dat' % (details.plotFile)
    i = 0
    with open(dataFileName, 'w+') as outF:
        while i < len(final) and counts[i] != 0:
            avgCorrect = final[i][0] / counts[i]
            avgPicked = final[i][1] / counts[i]
            if nModels > 1:
                _, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                outF.write('%d %f %f %d %f\n' % (i + 1, avgCorrect, avgPicked, counts[i],
                                                 confInterval))
            else:
                outF.write('%d %f %f %d\n' % (i + 1, avgCorrect, avgPicked, counts[i]))
            i += 1
    nRows = i

    plotFileName = '%s.gnu' % (details.plotFile)
    gnuHdr = """# Generated by EnrichPlot.py version: %s
    set size square 0.7
    set xr [0:]
    set data styl points
    set ylab 'Num Correct Picks'
    set xlab 'Num Picks'
    set grid
    set nokey
    set term postscript enh color solid "Helvetica" 16
    set term X
    """ % (__VERSION_STRING)
    with open(plotFileName, 'w+') as gnuF:
        print(gnuHdr, file=gnuF)
        if nTrueActs > 0:
            print('set yr [0:%d]' % nTrueActs, file=gnuF)
        print('plot x with lines', file=gnuF)
        if nModels > 1:
            # error bar on roughly every 20th row; the increment must be a
            # positive integer, so never let it drop to 0 (fewer than 20 rows)
            everyGap = max(1, nRows // 20)
            print('replot "%s" using 1:2 with lines,' % (dataFileName), end='', file=gnuF)
            print('"%s" every %d using 1:2:5 with yerrorbars' % (dataFileName, everyGap),
                  file=gnuF)
        else:
            print('replot "%s" with points' % (dataFileName), file=gnuF)

    if getattr(details, 'showPlot', False):
        try:
            from Gnuplot import Gnuplot
            p = Gnuplot()
            p('load "%s"' % (plotFileName))
            input('press return to continue...\n')
        except Exception:
            # displaying is best-effort (the Gnuplot module is optional);
            # the .dat/.gnu files have already been written
            import traceback
            traceback.print_exc()
282 283
def Usage():
    """ writes this module's docstring (the command-line help) to
    _sys.stderr_ and then exits with status -1 """
    helpText = __doc__
    sys.stderr.write(helpText)
    raise SystemExit(-1)
if __name__ == '__main__':
    import getopt
    try:
        args, extras = getopt.getopt(sys.argv[1:], 'd:t:a:N:p:cSTHF:v:',
                                     ('thresh=', 'plotFile=', 'showPlot', 'pickleCol=', 'OOB',
                                      'noSort', 'pickBase=', 'doROC', 'rocThresh=', 'enrich='))
    except Exception:
        import traceback
        traceback.print_exc()
        Usage()

    details = CompositeRun.CompositeRun()
    CompositeRun.SetDefaults(details)

    details.activeTgt = [1]
    details.doTraining = 0
    details.doHoldout = 0
    details.dbTableName = ''
    details.plotFile = ''
    details.showPlot = 0
    details.pickleCol = -1
    details.errorEstimate = 0
    details.sortIt = 1
    details.pickBase = ''
    details.doROC = 0
    details.rocThresh = -1
    for arg, val in args:
        if arg == '-d':
            details.dbName = val
        elif arg == '-t':
            details.dbTableName = val
        elif arg == '-a' or arg == '--enrich':
            # NOTE(review): the value is eval'ed (as the usage text documents);
            # acceptable for a local command-line tool, never pass untrusted input
            details.activeTgt = eval(val)
            if not isinstance(details.activeTgt, (tuple, list)):
                details.activeTgt = (details.activeTgt, )
        elif arg == '--thresh':
            details.threshold = float(val)
        elif arg == '-N':
            details.note = val
        elif arg == '-p':
            details.persistTblName = val
        elif arg == '-S':
            details.shuffleActivities = 1
        elif arg == '-H':
            details.doTraining = 0
            details.doHoldout = 1
        elif arg == '-T':
            details.doTraining = 1
            details.doHoldout = 0
        elif arg == '-F':
            details.filterFrac = float(val)
        elif arg == '-v':
            details.filterVal = float(val)
        elif arg == '--plotFile':
            details.plotFile = val
        elif arg == '--showPlot':
            details.showPlot = 1
        elif arg == '--pickleCol':
            details.pickleCol = int(val) - 1
        elif arg == '--OOB':
            details.errorEstimate = 1
        elif arg == '--noSort':
            details.sortIt = 0
        elif arg == '--doROC':
            details.doROC = 1
        elif arg == '--rocThresh':
            details.rocThresh = int(val)
        elif arg == '--pickBase':
            details.pickBase = val

    if not details.dbName or not details.dbTableName:
        # must be printed before Usage(), which exits
        print('*******Please provide both the -d and -t arguments')
        Usage()

    message('Building Data set\n')
    dataSet = DataUtils.DBToData(details.dbName, details.dbTableName, user=RDConfig.defaultDBUser,
                                 password=RDConfig.defaultDBPassword, pickleCol=details.pickleCol,
                                 pickleClass=DataStructs.ExplicitBitVect)

    descs = dataSet.GetVarNames()
    nPts = dataSet.GetNPts()
    message('npts: %d\n' % (nPts))
    # builtin float/int as dtypes: numpy.float was removed in NumPy 1.24 and
    # the abstract numpy.integer is deprecated as a dtype
    final = numpy.zeros((nPts, 2), float)
    counts = numpy.zeros(nPts, int)
    selPts = [None] * nPts

    models = []
    if details.persistTblName:
        conn = DbConnect(details.dbName, details.persistTblName)
        message('-> Retrieving models from database')
        curs = conn.GetCursor()
        # NOTE(review): table name and note are interpolated into the SQL text;
        # only run this with trusted command-line arguments
        curs.execute("select model from %s where note='%s'" % (details.persistTblName, details.note))
        message('-> Reconstructing models')
        try:
            blob = curs.fetchone()
        except Exception:
            blob = None
        while blob:
            message(' Building model %d' % len(models))
            blob = blob[0]
            try:
                # bytes(), not str(): on Python 3 str() of a bytes/memoryview blob
                # gives its repr rather than the pickled payload.
                # NOTE(review): unpickling executes code from the database; trust it.
                models.append(cPickle.loads(bytes(blob)))
            except Exception:
                import traceback
                traceback.print_exc()
                print('Model failed')
            else:
                message(' <-Done')
            try:
                blob = curs.fetchone()
            except Exception:
                blob = None
        curs = None
    else:
        for modelName in extras:
            try:
                with open(modelName, 'rb') as modelF:
                    model = cPickle.load(modelF)
            except Exception:
                import traceback
                print('problems with model %s:' % modelName)
                traceback.print_exc()
            else:
                models.append(model)
    nModels = len(models)
    if not nModels:
        # nothing to screen; later steps would fail (e.g. nTrueActives unset)
        error('no models could be loaded')
        sys.exit(-1)

    pickVects = {}
    halfwayPts = [1e8] * nModels
    for whichModel, model in enumerate(models):
        # NOTE(review): tmpD aliases dataSet (no copy); if RandomizeActivities
        # shuffles in place, shuffles accumulate across models — confirm intended
        tmpD = dataSet
        try:
            seed = model._randomSeed
        except AttributeError:
            pass
        else:
            DataUtils.InitRandomNumbers(seed)
        if details.shuffleActivities:
            DataUtils.RandomizeActivities(tmpD, shuffle=1)
        if hasattr(model, '_splitFrac') and (details.doHoldout or details.doTraining):
            trainIdx, testIdx = SplitData.SplitIndices(tmpD.GetNPts(), model._splitFrac, silent=1)
            if details.filterFrac != 0.0:
                trainFilt, temp = DataUtils.FilterData(tmpD, details.filterVal, details.filterFrac,
                                                       -1, indicesToUse=trainIdx, indicesOnly=1)
                testIdx += temp
                trainIdx = trainFilt
            if details.doTraining:
                testIdx, trainIdx = trainIdx, testIdx
        else:
            testIdx = list(range(tmpD.GetNPts()))

        message('screening %d examples' % (len(testIdx)))
        nTrueActives, screenRes = ScreenModel(model, descs, tmpD, picking=details.activeTgt,
                                              indices=testIdx, errorEstimate=details.errorEstimate)
        message('accumulating')
        runningCounts = AccumulateCounts(screenRes, sortIt=details.sortIt, thresh=details.threshold)

        pickFile = None
        if details.pickBase:
            pickFile = open('%s.%d.picks' % (details.pickBase, whichModel + 1), 'w+')
        try:
            for i, entry in enumerate(runningCounts):
                selPts[i] = entry[0]
                final[i][0] += entry[1]
                final[i][1] += entry[2]
                pickVects.setdefault(i, []).append(entry[1])
                counts[i] += 1
                if pickFile is not None:
                    pickFile.write('%s\n' % (entry[0]))
                if entry[1] >= nTrueActives / 2 and entry[2] < halfwayPts[whichModel]:
                    halfwayPts[whichModel] = entry[2]
        finally:
            # close (and so flush) the per-model picks file
            if pickFile is not None:
                pickFile.close()
        message('Halfway point: %d\n' % halfwayPts[whichModel])

    if details.plotFile:
        # NOTE: nTrueActives is the count from the last model screened
        MakePlot(details, final, counts, pickVects, nModels, nTrueActs=nTrueActives)
    else:
        if nModels > 1:
            print('#Index\tAvg_num_correct\tConf90Pct\tAvg_num_picked\tNum_picks\tlast_selection')
        else:
            print('#Index\tAvg_num_correct\tAvg_num_picked\tNum_picks\tlast_selection')

        i = 0
        while i < nPts and counts[i] != 0:
            if nModels > 1:
                mean, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                print('%d\t%f\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i], confInterval,
                                                 final[i][1] / counts[i], counts[i], str(selPts[i])))
            else:
                print('%d\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i],
                                             final[i][1] / counts[i], counts[i], str(selPts[i])))
            i += 1

    mean, sd = Stats.MeanAndDev(halfwayPts)
    print('Halfway point: %.2f(%.2f)' % (mean, sd))