Package rdkit ::
Package ML ::
Module GrowComposite
|
|
1
2
3
4
5
6
7
8
9
10
11 """ command line utility for growing composite models
12
13 **Usage**
14
15 _GrowComposite [optional args] filename_
16
17 **Command Line Arguments**
18
19 - -n *count*: number of new models to build
20
21 - -C *pickle file name*: name of file containing composite upon which to build.
22
23 - --inNote *note*: note to be used in loading composite models from the database
24 for growing
25
26 - --balTable *table name*: table from which to take the original data set
27 (for balancing)
28
29 - --balWeight *weight*: (between 0 and 1) weighting factor for the new data
30 (for balancing). OR, *weight* can be a list of weights
31
32 - --balCnt *count*: number of individual models in the balanced composite
33 (for balancing)
34
35 - --balH: use only the holdout set from the original data set in the balancing
36 (for balancing)
37
38 - --balT: use only the training set from the original data set in the balancing
39 (for balancing)
40
41 - -S: shuffle the original data set
42 (for balancing)
43
44 - -r: randomize the activities of the original data set
45 (for balancing)
46
47 - -N *note*: note to be attached to the grown composite when it's saved in the
48 database
49
50 - --outNote *note*: equivalent to -N
51
52 - -o *filename*: name of an output file to hold the pickled composite after
53 it has been grown.
54 If multiple balance weights are used, the weights will be added to
55 the filenames.
56
57 - -L *limit*: provide an (integer) limit on individual model complexity
58
59 - -d *database name*: instead of reading the data from a QDAT file,
60 pull it from a database. In this case, the _filename_ argument
61 provides the name of the database table containing the data set.
62
63 - -p *tablename*: store persistence data in the database
64 in table *tablename*
65
66 - -l: locks the random number generator to give consistent sets
67 of training and hold-out data. This is primarily intended
68 for testing purposes.
69
70 - -g: be less greedy when training the models.
71
72 - -G *number*: force trees to be rooted at descriptor *number*.
73
74 - -D: show a detailed breakdown of the composite model performance
75 across the training and, when appropriate, hold-out sets.
76
77 - -t *threshold value*: use high-confidence predictions for the final
78 analysis of the hold-out data.
79
80 - -q *list string*: Add QuantTrees to the composite and use the list
81 specified in *list string* as the number of target quantization
82 bounds for each descriptor. Don't forget to include 0's at the
83 beginning and end of *list string* for the name and value fields.
84 For example, if there are 4 descriptors and you want 2 quant bounds
85 apiece, you would use _-q "[0,2,2,2,2,0]"_.
86 Two special cases:
87 1) If you would like to ignore a descriptor in the model building,
88 use '-1' for its number of quant bounds.
89 2) If you have integer valued data that should not be quantized
90 further, enter 0 for that descriptor.
91
92 - -V: print the version number and exit
93
94 """
95 from __future__ import print_function
96
97 import sys
98 import time
99
100 import numpy
101
102 from rdkit.Dbase.DbConnection import DbConnect
103 from rdkit.ML import CompositeRun
104 from rdkit.ML import ScreenComposite, BuildComposite
105 from rdkit.ML.Composite import AdjustComposite
106 from rdkit.ML.Data import DataUtils, SplitData
107 from rdkit.six.moves import cPickle
108
# Verbosity switch: message() and GrowIt() only emit output while this is nonzero.
_verbose = 1

# Version string reported by ShowVersion().
__VERSION_STRING = '0.5.0'

# Module-wide run-details object.  SetDefaults() falls back to it when no
# details object is supplied, GrowIt() records its summary statistics on it,
# and the command-line driver uses it to hold the parsed options.
_runDetails = CompositeRun.CompositeRun()
114
115
def message(msg):
    """ emits messages to _sys.stdout_
      override this in modules which import this one to redirect output

      Nothing is written while the module-level _verbose flag is false.

      **Arguments**

        - msg: the string to be displayed (a newline is appended)

    """
    if not _verbose:
        return
    sys.stdout.write('%s\n' % (msg))
127
128
def GrowIt(details, composite, progressCallback=None, saveIt=1, setDescNames=0, data=None):
    """ does the actual work of growing a composite model

      **Arguments**

        - details: a _CompositeRun.CompositeRun_ object containing details
          (options, parameters, etc.) about the run

        - composite: the composite model to grow

        - progressCallback: (optional) a function which is called with a single
          argument (the number of models built so far) after each model is built.
          It is only forwarded when trees are being built (_details.useTrees_).

        - saveIt: (optional) accepted for backwards compatibility.  This function
          does not read it: nothing is pickled here, saving is left to the caller.

        - setDescNames: (optional) if nonzero, and trees are being built, the
          composite's _SetInputOrder()_ method is called with the results of the
          data set's _GetVarNames()_ method before growing.

        - data: (optional) the data set to be used.  If this is not provided, the
          data set described in details is loaded: from the file named by
          _details.tableName_ when _details.dbName_ is empty, otherwise from the
          database (via _details.GetDataSet()_ when _details.qBounds_ is set,
          via _DataUtils.DBToQuantData()_ when it is not).

      **Returns**

        the enlarged composite model

      **Notes**

        - _details.rundate_ is always updated; _details.outName_ and
          _details.tableName_ may be updated while the data set is loaded.

        - the summary statistics (_overall_error_ etc.) are recorded on the
          module-level _runDetails object, not on _details_.

    """
    details.rundate = time.asctime()

    if data is None:
        fName = details.tableName.strip()
        if details.outName == '':
            details.outName = fName + '.pkl'
        if details.dbName == '':
            data = DataUtils.BuildQuantDataSet(fName)
        elif details.qBounds != []:
            details.tableName = fName
            data = details.GetDataSet()
        else:
            data = DataUtils.DBToQuantData(details.dbName, fName, quantName=details.qTableName,
                                           user=details.dbUser, password=details.dbPassword)

    # the random number generator is (re)initialised from the seed stored on the composite
    seed = composite._randomSeed
    DataUtils.InitRandomNumbers(seed)
    if details.shuffleActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=1, runDetails=details)
    elif details.randomActivities == 1:
        DataUtils.RandomizeActivities(data, shuffle=0, runDetails=details)

    # the whole data set is used for training; no hold-out split is made here
    namedExamples = data.GetNamedData()
    trainExamples = namedExamples
    nExamples = len(trainExamples)
    message('Training with %d examples' % (nExamples))
    message('\t%d descriptors' % (len(trainExamples[0]) - 2))
    nVars = data.GetNVars()
    nPossibleVals = composite.nPossibleVals
    attrs = list(range(1, nVars + 1))

    if details.useTrees:
        from rdkit.ML.DecTree import CrossValidate, PruneTree
        if details.qBounds != []:
            from rdkit.ML.DecTree import BuildQuantTree
            builder = BuildQuantTree.QuantTreeBoot
        else:
            from rdkit.ML.DecTree import ID3
            builder = ID3.ID3Boot
        driver = CrossValidate.CrossValidationDriver
        pruner = PruneTree.PruneTree

        if setDescNames:
            composite.SetInputOrder(data.GetVarNames())
        composite.Grow(trainExamples, attrs, [0] + nPossibleVals, buildDriver=driver,
                       pruner=pruner, nTries=details.nModels, pruneIt=details.pruneIt,
                       lessGreedy=details.lessGreedy, needsQuantization=0, treeBuilder=builder,
                       nQuantBounds=details.qBounds, startAt=details.startAt,
                       maxDepth=details.limitDepth, progressCallback=progressCallback,
                       silent=not _verbose)

    else:
        from rdkit.ML.Neural import CrossValidate
        driver = CrossValidate.CrossValidationDriver
        composite.Grow(trainExamples, attrs, [0] + nPossibleVals, nTries=details.nModels,
                       buildDriver=driver, needsQuantization=0)

    composite.AverageErrors()
    composite.SortModels()
    modelList, counts, avgErrs = composite.GetAllData()
    counts = numpy.array(counts)
    avgErrs = numpy.array(avgErrs)
    composite._varNames = data.GetVarNames()

    for model in modelList:
        model.NameModel(composite._varNames)

    # count-weighted mean of the per-model errors, and the count-weighted
    # mean absolute deviation from that mean
    weightedErrs = counts * avgErrs
    averageErr = sum(weightedErrs) / sum(counts)
    devs = (avgErrs - averageErr)
    devs = devs * counts
    devs = numpy.sqrt(devs * devs)
    avgDev = sum(devs) / sum(counts)
    if _verbose:
        message('# Overall Average Error: %%% 5.2f, Average Deviation: %%% 6.2f' %
                (100. * averageErr, 100. * avgDev))

    if details.bayesModel:
        composite.Train(trainExamples, verbose=0)

    # NOTE(review): the statistics below are written to the module-level
    # _runDetails rather than to the details argument; the two only coincide
    # when the caller passes _runDetails in.  Left as is because callers may
    # read the global.
    if not details.detailedRes:
        badExamples = []
        if _verbose:
            message('Testing all examples')
        wrong = BuildComposite.testall(composite, namedExamples, badExamples)
        if _verbose:
            message('%d examples (%% %5.2f) were misclassified' %
                    (len(wrong), 100. * float(len(wrong)) / float(len(namedExamples))))
        _runDetails.overall_error = float(len(wrong)) / len(namedExamples)
    else:
        if _verbose:
            message('\nEntire data set:')
        resTup = ScreenComposite.ShowVoteResults(
            range(data.GetNPts()), data, composite, nPossibleVals[-1], details.threshold)
        nGood, nBad, _, avgGood, avgBad, _, voteTab = resTup
        nPts = len(namedExamples)
        nClass = nGood + nBad
        _runDetails.overall_error = float(nBad) / nClass
        _runDetails.overall_correct_conf = avgGood
        _runDetails.overall_incorrect_conf = avgBad
        _runDetails.overall_result_matrix = repr(voteTab)
        # points which ended up in neither the "good" nor the "bad" bin were
        # dropped.  The count used to be computed as nClass - nPts, which is
        # never positive, so the dropped fraction was never recorded.
        nRej = nPts - nClass
        if nRej > 0:
            _runDetails.overall_fraction_dropped = float(nRej) / nPts

    return composite
268
269
def GetComposites(details):
    """ loads the composite model(s) which are to be grown

      **Arguments**

        - details: a _CompositeRun.CompositeRun_ object.  If both
          _details.persistTblName_ and _details.inNote_ are set, every model
          stored in that table of _details.dbName_ with a matching note is
          loaded; otherwise, if _details.composFileName_ is set, a single
          model is unpickled from that file.

      **Returns**

        a (possibly empty) list of composite models

      **Notes**

        - SECURITY: the models are unpickled, and unpickling can execute
          arbitrary code.  Only use databases and files from trusted sources.

    """
    res = []
    if details.persistTblName and details.inNote:
        conn = DbConnect(details.dbName, details.persistTblName)
        # double any embedded single quotes so that the note cannot terminate
        # the SQL string literal early
        note = details.inNote.replace("'", "''")
        mdls = conn.GetData(fields='MODEL', where="where note='%s'" % (note))
        for row in mdls:
            # bytes() is str() on Python 2 and, unlike str(), also yields
            # something pickle can load on Python 3
            res.append(cPickle.loads(bytes(row[0])))
    elif details.composFileName:
        with open(details.composFileName, 'rb') as inF:
            res.append(cPickle.load(inF))
    return res
281
282
def BalanceComposite(details, composite, data1=None, data2=None):
    """ balances the composite using the parameters provided in details

      **Arguments**

        - details a _CompositeRun.RunDetails_ object

        - composite: the composite model to be balanced

        - data1: (optional) if provided, this should be the
          data set used to construct the original models

        - data2: (optional) if provided, this should be the
          data set used to construct the new individual models

      **Returns**

        - _composite_ itself, unchanged and NOT wrapped in a list, if no
          balancing is done (_details.balCnt_ is zero or larger than
          _len(composite)_, or one of the data sets could not be loaded)

        - otherwise a list containing one balanced composite for each weight
          in _details.balWeight_

      **Notes**

        - _details.splitFrac_ and _details.randomSeed_ are overwritten with
          the values stored on the composite.

    """
    if not details.balCnt or details.balCnt > len(composite):
        return composite
    message("Balancing Composite")

    #
    # start by getting data set 1: which is the data set used to build the
    # original models
    #
    if data1 is None:
        message("\tReading First Data Set")
        origTable, origDb = details.tableName, details.dbName
        details.tableName = details.balTable.strip()
        details.dbName = details.balDb
        try:
            data1 = details.GetDataSet()
        finally:
            # always put the caller's table and database names back, even if
            # loading the data set fails
            details.tableName = origTable
            details.dbName = origDb
    if data1 is None:
        return composite
    details.splitFrac = composite._splitFrac
    details.randomSeed = composite._randomSeed
    DataUtils.InitRandomNumbers(details.randomSeed)
    if details.shuffleActivities == 1:
        DataUtils.RandomizeActivities(data1, shuffle=1, runDetails=details)
    elif details.randomActivities == 1:
        DataUtils.RandomizeActivities(data1, shuffle=0, runDetails=details)
    namedExamples = data1.GetNamedData()
    if details.balDoHoldout or details.balDoTrain:
        trainIdx, testIdx = SplitData.SplitIndices(len(namedExamples), details.splitFrac,
                                                   silent=1)
        trainExamples = [namedExamples[x] for x in trainIdx]
        testExamples = [namedExamples[x] for x in testIdx]
        if details.filterFrac != 0.0:
            # examples removed by the filter are moved to the test set
            keepIdx, dropIdx = DataUtils.FilterData(trainExamples, details.filterVal,
                                                    details.filterFrac, -1, indicesOnly=1)
            kept = [trainExamples[x] for x in keepIdx]
            testExamples += [trainExamples[x] for x in dropIdx]
            trainExamples = kept
        if details.balDoHoldout:
            # balance against the hold-out examples instead of the training ones
            testExamples, trainExamples = trainExamples, testExamples
    else:
        trainExamples = namedExamples
    dataSet1 = trainExamples
    cols1 = [x.upper() for x in data1.GetVarNames()]
    data1 = None

    #
    # now grab data set 2: the data used to build the new individual models
    #
    if data2 is None:
        message("\tReading Second Data Set")
        data2 = details.GetDataSet()
    if data2 is None:
        return composite
    details.splitFrac = composite._splitFrac
    details.randomSeed = composite._randomSeed
    DataUtils.InitRandomNumbers(details.randomSeed)
    if details.shuffleActivities == 1:
        DataUtils.RandomizeActivities(data2, shuffle=1, runDetails=details)
    elif details.randomActivities == 1:
        DataUtils.RandomizeActivities(data2, shuffle=0, runDetails=details)
    dataSet2 = data2.GetNamedData()
    cols2 = [x.upper() for x in data2.GetVarNames()]
    data2 = None

    # and balance it, once for each requested weight:
    res = []
    weights = details.balWeight
    if not isinstance(weights, (tuple, list)):
        weights = (weights, )
    for weight in weights:
        message("\tBalancing with Weight: %.4f" % (weight))
        res.append(
            AdjustComposite.BalanceComposite(composite, dataSet1, dataSet2, weight,
                                             details.balCnt, names1=cols1, names2=cols2))
    return res
375
376
def ShowVersion(includeArgs=0):
    """ prints the version number

      **Arguments**

        - includeArgs: (optional) if nonzero, the command line (_sys.argv_)
          is printed as well

    """
    print('This is GrowComposite.py version %s' % (__VERSION_STRING))
    if includeArgs:
        print('command line was:')
        print(' '.join(sys.argv))
385
386
def Usage():
    """ provides a list of arguments for when this is used from the command line

      Prints the module docstring and then terminates the program
      with exit status -1; this function never returns.

    """
    print(__doc__)
    sys.exit(-1)
393
394
def SetDefaults(runDetails=None):
    """ initializes a details object with default values

      **Arguments**

        - runDetails: (optional) a _CompositeRun.CompositeRun_ object.
          If this is not provided, the global _runDetails will be used.

      **Returns**

        the result of _CompositeRun.SetDefaults()_ for that object, i.e.
        the initialized _CompositeRun_ object.

    """
    if runDetails is None:
        runDetails = _runDetails
    return CompositeRun.SetDefaults(runDetails)
412
413
def ParseArgs(runDetails):
    """ parses command line arguments and updates _runDetails_

      The arguments are taken from _sys.argv[1:]_.  The fields specific to
      growing and balancing (_inNote_, _composFileName_, _balTable_,
      _balWeight_, _balCnt_, _balDoHoldout_, _balDoTrain_, _balDb_) are reset
      to their defaults before the options are applied.  The first positional
      argument becomes _runDetails.tableName_, and _runDetails.balDb_ falls
      back to _runDetails.dbName_ when it was not given.

      **Arguments**

        - runDetails: a _CompositeRun.CompositeRun_ object.

      **Notes**

        - _-h_, an unhandled option, or a missing positional argument print
          the usage message and exit; _-V_ prints the version and exits.

        - the values of _--balWeight_, _-q_ and _-Q_ are passed to _eval()_.
          SECURITY: never feed this function a command line from an
          untrusted source.

    """
    import getopt
    args, extra = getopt.getopt(sys.argv[1:], 'P:o:n:p:b:sf:F:v:hlgd:rSTt:Q:q:DVG:L:C:N:',
                                ['inNote=',
                                 'outNote=',
                                 'balTable=',
                                 'balWeight=',
                                 'balCnt=',
                                 'balH',
                                 'balT',
                                 'balDb=', ])
    runDetails.inNote = ''
    runDetails.composFileName = ''
    runDetails.balTable = ''
    runDetails.balWeight = (0.5, )
    runDetails.balCnt = 0
    runDetails.balDoHoldout = 0
    runDetails.balDoTrain = 0
    runDetails.balDb = ''
    for arg, val in args:
        if arg == '-n':
            runDetails.nModels = int(val)
        elif arg == '-C':
            runDetails.composFileName = val
        elif arg == '--balTable':
            runDetails.balTable = val
        elif arg == '--balWeight':
            # a single number or a list/tuple of numbers; always stored as a sequence
            runDetails.balWeight = eval(val)
            if not isinstance(runDetails.balWeight, (tuple, list)):
                runDetails.balWeight = (runDetails.balWeight, )
        elif arg == '--balCnt':
            runDetails.balCnt = int(val)
        elif arg == '--balH':
            runDetails.balDoHoldout = 1
        elif arg == '--balT':
            runDetails.balDoTrain = 1
        elif arg == '--balDb':
            runDetails.balDb = val
        elif arg == '--inNote':
            runDetails.inNote = val
        elif arg == '-N' or arg == '--outNote':
            runDetails.note = val
        elif arg == '-o':
            runDetails.outName = val
        elif arg == '-p':
            runDetails.persistTblName = val
        elif arg == '-r':
            runDetails.randomActivities = 1
        elif arg == '-S':
            runDetails.shuffleActivities = 1
        elif arg == '-h':
            Usage()
        elif arg == '-l':
            runDetails.lockRandom = 1
        elif arg == '-g':
            runDetails.lessGreedy = 1
        elif arg == '-G':
            runDetails.startAt = int(val)
        elif arg == '-d':
            runDetails.dbName = val
        elif arg == '-T':
            runDetails.useTrees = 0
        elif arg == '-t':
            runDetails.threshold = float(val)
        elif arg == '-D':
            runDetails.detailedRes = 1
        elif arg == '-L':
            runDetails.limitDepth = int(val)
        elif arg == '-q':
            qBounds = eval(val)
            if not isinstance(qBounds, (tuple, list)):
                # raised explicitly (instead of using "assert") so that the check
                # survives "python -O"; the exception type is unchanged
                raise AssertionError('bad argument type for -q, specify a list as a string')
            runDetails.qBoundCount = val
            runDetails.qBounds = qBounds
        elif arg == '-Q':
            qBounds = eval(val)
            if not isinstance(qBounds, (tuple, list)):
                raise AssertionError('bad argument type for -Q, specify a list as a string')
            runDetails.activityBounds = qBounds
            runDetails.activityBoundsVals = val
        elif arg == '-V':
            ShowVersion()
            sys.exit(0)
        else:
            # options accepted by the getopt string but not handled above end up here
            print('bad argument:', arg, file=sys.stderr)
            Usage()
    if not extra:
        # without this check a missing file/table name surfaced as an IndexError
        print('no data file or table name provided', file=sys.stderr)
        Usage()
    runDetails.tableName = extra[0]
    if not runDetails.balDb:
        runDetails.balDb = runDetails.dbName
510
511
# Command-line driver: load the starting composite(s), grow each one,
# optionally balance it, then save the result to a file and/or the database.
if __name__ == '__main__':
    if len(sys.argv) < 2:
        Usage()

    _runDetails.cmd = ' '.join(sys.argv)
    SetDefaults(_runDetails)
    ParseArgs(_runDetails)

    ShowVersion(includeArgs=1)

    initModels = GetComposites(_runDetails)
    nModels = len(initModels)
    if not nModels:
        message("No models found")

    for modelIdx, initModel in enumerate(initModels):
        if nModels > 1:
            sys.stderr.write(
                '---------------------------------\n\tDoing %d of %d\n---------------------------------\n' %
                (modelIdx + 1, nModels))
        composite = GrowIt(_runDetails, initModel, setDescNames=1)
        if _runDetails.balTable and _runDetails.balCnt:
            composites = BalanceComposite(_runDetails, composite)
            if composites is composite:
                # BalanceComposite() hands the bare composite back (not a list)
                # when it decides not to balance
                composites = [composite]
        else:
            composites = [composite]
        for mdl in composites:
            mdl.ClearModelExamples()

        if _runDetails.outName:
            # NOTE(review): when several starting models are processed, each pass
            # writes to the same file name(s), so later results overwrite earlier ones.
            if len(composites) == 1:
                composites[0].Pickle(_runDetails.outName)
            else:
                # one file per balance weight, with the weight (in percent) in the name
                baseName = _runDetails.outName.split('.pkl')[0]
                for weight, mdl in zip(_runDetails.balWeight, composites):
                    mdl.Pickle('%s.%d.pkl' % (baseName, int(100 * weight)))

        if _runDetails.persistTblName and _runDetails.dbName:
            message('Updating results table %s:%s' %
                    (_runDetails.dbName, _runDetails.persistTblName))
            if len(composites) > 1:
                message('WARNING: updating results table with models having different weights')
            for mdl in composites:
                _runDetails.model = cPickle.dumps(mdl)
                _runDetails.Store(db=_runDetails.dbName, table=_runDetails.persistTblName)
583