Package rdkit :: Package ML :: Module AnalyzeComposite

11 """ command line utility to report on the contributions of descriptors to
12 tree-based composite models
13
14 Usage: AnalyzeComposite [optional args] <models>
15
16 <models>: file name(s) of pickled composite model(s)
17 (this is the name of the db table if using a database)
18
19 Optional Arguments:
20
21 -n number: the number of levels of each model to consider
22
23 -d dbname: the database from which to read the models
24
25 -N Note: the note string to search for to pull models from the database
26
27 -v: be verbose whilst screening
28 """
from __future__ import print_function

import sys

import numpy

from rdkit.Dbase.DbConnection import DbConnect
from rdkit.ML import ScreenComposite
from rdkit.ML.Data import Stats
from rdkit.ML.DecTree import TreeUtils, Tree
from rdkit.six.moves import cPickle


__VERSION_STRING = "2.2.0"


def ProcessIt(composites, nToConsider=3, verbose=0):
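  # Tally how often each descriptor shows up in the top nToConsider levels of the
  # tree models in each composite, average those per-level frequencies over all of
  # the composites, optionally print the table (verbose >= 0) and return it as rows
  # of [descriptor name, per-level frequencies..., total].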
  composite = composites[0]
  nComposites = len(composites)
  ns = composite.GetDescriptorNames()

  if len(ns) > 2:
    globalRes = {}

    nDone = 1
    descNames = {}
    for composite in composites:
      if verbose > 0:
        print('#------------------------------------')
        print('Doing: ', nDone)
      nModels = len(composite)
      nDone += 1
      res = {}
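      # res: descriptor id -> per-level frequency vector for this composite; each tree
      # contributes 1/nModels at the level where CollectLabelLevels places the descriptor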
      for i in range(len(composite)):
        model = composite.GetModel(i)
        if isinstance(model, Tree.TreeNode):
          levels = TreeUtils.CollectLabelLevels(model, {}, 0, nToConsider)
          TreeUtils.CollectDescriptorNames(model, descNames, 0, nToConsider)
          for descId in levels.keys():
            v = res.get(descId, numpy.zeros(nToConsider, numpy.float))
            v[levels[descId]] += 1. / nModels
            res[descId] = v
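      # fold this composite's per-level frequencies into the running average over composites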
      for k in res:
        v = globalRes.get(k, numpy.zeros(nToConsider, numpy.float))
        v += res[k] / nComposites
        globalRes[k] = v
      if verbose > 0:
        for k in res.keys():
          name = descNames[k]
          strRes = ', '.join(['%4.2f' % x for x in res[k]])
          print('%s,%s,%5.4f' % (name, strRes, sum(res[k])))

        print()

    if verbose >= 0:
      print('# Average Descriptor Positions')
    retVal = []
    for k in globalRes:
      name = descNames[k]
      if verbose >= 0:
        strRes = ', '.join(['%4.2f' % x for x in globalRes[k]])
        print('%s,%s,%5.4f' % (name, strRes, sum(globalRes[k])))
      tmp = [name]
      tmp.extend(globalRes[k])
      tmp.append(sum(globalRes[k]))
      retVal.append(tmp)
    if verbose >= 0:
      print()
  else:
    retVal = []
  return retVal


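# Summarize the screening statistics stored in the database table: average the
# overall and holdout errors, confidence values and result matrices (and, when
# enrich >= 0, the enrichments) over all runs matching the where clause.
# Returns a dict of summary values, or None if the query fails or finds no runs.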
def ErrorStats(conn, where, enrich=1):
  fields = ('overall_error,holdout_error,overall_result_matrix,' +
            'holdout_result_matrix,overall_correct_conf,overall_incorrect_conf,' +
            'holdout_correct_conf,holdout_incorrect_conf')
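  # one row of statistics per screening run; give up if the query fails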
  try:
    data = conn.GetData(fields=fields, where=where)
  except Exception:
    import traceback
    traceback.print_exc()
    return None
  nPts = len(data)
  if not nPts:
    sys.stderr.write('no runs found\n')
    return None
  overall = numpy.zeros(nPts, numpy.float)
  overallEnrich = numpy.zeros(nPts, numpy.float)
  oCorConf = 0.0
  oInCorConf = 0.0
  holdout = numpy.zeros(nPts, numpy.float)
  holdoutEnrich = numpy.zeros(nPts, numpy.float)
  hCorConf = 0.0
  hInCorConf = 0.0
  overallMatrix = None
  holdoutMatrix = None
  for i in range(nPts):
    if data[i][0] is not None:
      overall[i] = data[i][0]
      oCorConf += data[i][4]
      oInCorConf += data[i][5]
    if data[i][1] is not None:
      holdout[i] = data[i][1]
      haveHoldout = 1
    else:
      haveHoldout = 0
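    # the stored result matrices are eval()'ed back from their text representation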
    tmpOverall = 1. * eval(data[i][2])
    if enrich >= 0:
      overallEnrich[i] = ScreenComposite.CalcEnrichment(tmpOverall, tgt=enrich)
    if haveHoldout:
      tmpHoldout = 1. * eval(data[i][3])
      if enrich >= 0:
        holdoutEnrich[i] = ScreenComposite.CalcEnrichment(tmpHoldout, tgt=enrich)
    if overallMatrix is None:
      if data[i][2] is not None:
        overallMatrix = tmpOverall
      if haveHoldout and data[i][3] is not None:
        holdoutMatrix = tmpHoldout
    else:
      overallMatrix += tmpOverall
      if haveHoldout:
        holdoutMatrix += tmpHoldout
    if haveHoldout:
      hCorConf += data[i][6]
      hInCorConf += data[i][7]

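  # turn the accumulated sums into averages, standard deviations and percentages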
  avgOverall = sum(overall) / nPts
  oCorConf /= nPts
  oInCorConf /= nPts
  overallMatrix /= nPts
  oSort = numpy.argsort(overall)
  oMin = overall[oSort[0]]
  overall -= avgOverall
  devOverall = numpy.sqrt(sum(overall**2) / (nPts - 1))
  res = {}
  res['oAvg'] = 100 * avgOverall
  res['oDev'] = 100 * devOverall
  res['oCorrectConf'] = 100 * oCorConf
  res['oIncorrectConf'] = 100 * oInCorConf
  res['oResultMat'] = overallMatrix
  res['oBestIdx'] = oSort[0]
  res['oBestErr'] = 100 * oMin

  if enrich >= 0:
    mean, dev = Stats.MeanAndDev(overallEnrich)
    res['oAvgEnrich'] = mean
    res['oDevEnrich'] = dev

  if haveHoldout:
    avgHoldout = sum(holdout) / nPts
    hCorConf /= nPts
    hInCorConf /= nPts
    holdoutMatrix /= nPts
    hSort = numpy.argsort(holdout)
    hMin = holdout[hSort[0]]
    holdout -= avgHoldout
    devHoldout = numpy.sqrt(sum(holdout**2) / (nPts - 1))
    res['hAvg'] = 100 * avgHoldout
    res['hDev'] = 100 * devHoldout
    res['hCorrectConf'] = 100 * hCorConf
    res['hIncorrectConf'] = 100 * hInCorConf
    res['hResultMat'] = holdoutMatrix
    res['hBestIdx'] = hSort[0]
    res['hBestErr'] = 100 * hMin
    if enrich >= 0:
      mean, dev = Stats.MeanAndDev(holdoutEnrich)
      res['hAvgEnrich'] = mean
      res['hDevEnrich'] = dev
  return res


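# Pretty-print the summary dict produced by ErrorStats: the error statistics, the
# overall (and, if present, holdout) results matrices with per-row and per-column
# percentages, and any enrichment values.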
def ShowStats(statD, enrich=1):
  statD = statD.copy()
  statD['oBestIdx'] = statD['oBestIdx'] + 1
  txt = """
# Error Statistics:
\tOverall: %(oAvg)6.3f%% (%(oDev)6.3f) %(oCorrectConf)4.1f/%(oIncorrectConf)4.1f
\t\tBest: %(oBestIdx)d %(oBestErr)6.3f%%""" % (statD)
  if 'hAvg' in statD:
    statD['hBestIdx'] = statD['hBestIdx'] + 1
    txt += """
\tHoldout: %(hAvg)6.3f%% (%(hDev)6.3f) %(hCorrectConf)4.1f/%(hIncorrectConf)4.1f
\t\tBest: %(hBestIdx)d %(hBestErr)6.3f%%
""" % (statD)
  print(txt)
  print()
  print('# Results matrices:')
  print('\tOverall:')
  tmp = numpy.transpose(statD['oResultMat'])
  colCounts = sum(tmp)
  rowCounts = sum(tmp, 1)
  for i in range(len(tmp)):
    if rowCounts[i] == 0:
      rowCounts[i] = 1
    row = tmp[i]
    print('\t\t', end='')
    for j in range(len(row)):
      print('% 6.2f' % row[j], end='')
    print('\t| % 4.2f' % (100. * tmp[i, i] / rowCounts[i]))
  print('\t\t', end='')
  for i in range(len(tmp)):
    print('------', end='')
  print()
  print('\t\t', end='')
  for i in range(len(tmp)):
    if colCounts[i] == 0:
      colCounts[i] = 1
    print('% 6.2f' % (100. * tmp[i, i] / colCounts[i]), end='')
  print()
  if enrich > -1 and 'oAvgEnrich' in statD:
    print('\t\tEnrich(%d): %.3f (%.3f)' % (enrich, statD['oAvgEnrich'], statD['oDevEnrich']))

  if 'hResultMat' in statD:
    print('\tHoldout:')
    tmp = numpy.transpose(statD['hResultMat'])
    colCounts = sum(tmp)
    rowCounts = sum(tmp, 1)
    for i in range(len(tmp)):
      if rowCounts[i] == 0:
        rowCounts[i] = 1
      row = tmp[i]
      print('\t\t', end='')
      for j in range(len(row)):
        print('% 6.2f' % row[j], end='')
      print('\t| % 4.2f' % (100. * tmp[i, i] / rowCounts[i]))
    print('\t\t', end='')
    for i in range(len(tmp)):
      print('------', end='')
    print()
    print('\t\t', end='')
    for i in range(len(tmp)):
      if colCounts[i] == 0:
        colCounts[i] = 1
      print('% 6.2f' % (100. * tmp[i, i] / colCounts[i]), end='')
    print()
    if enrich > -1 and 'hAvgEnrich' in statD:
      print('\t\tEnrich(%d): %.3f (%.3f)' % (enrich, statD['hAvgEnrich'], statD['hDevEnrich']))

  return


def Usage():
  print(__doc__)
  sys.exit(-1)


if __name__ == "__main__":
  import getopt
  try:
    args, extras = getopt.getopt(sys.argv[1:], 'n:d:N:vX', ('skip',
                                                            'enrich=', ))
  except Exception:
    Usage()

  count = 3
  db = None
  note = ''
  verbose = 0
  skip = 0
  enrich = 1
  for arg, val in args:
    if arg == '-n':
      count = int(val) + 1
    elif arg == '-d':
      db = val
    elif arg == '-N':
      note = val
    elif arg == '-v':
      verbose = 1
    elif arg == '--skip':
      skip = 1
    elif arg == '--enrich':
      enrich = int(val)
  composites = []
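  # load the composite models, either from pickle files named on the command line
  # or from the 'model' column of the specified database table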
  if db is None:
    for arg in extras:
      composite = cPickle.load(open(arg, 'rb'))
      composites.append(composite)
  else:
    tbl = extras[0]
    conn = DbConnect(db, tbl)
    if note:
      where = "where note='%s'" % (note)
    else:
      where = ''
    if not skip:
      pkls = conn.GetData(fields='model', where=where)
      composites = []
      for pkl in pkls:
        pkl = str(pkl[0])
        comp = cPickle.loads(pkl)
        composites.append(comp)

  if len(composites):
    ProcessIt(composites, count, verbose=verbose)
  elif not skip:
    print('ERROR: no composite models found')
    sys.exit(-1)

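  # when the models come from a database, also summarize the stored error statistics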
  if db:
    res = ErrorStats(conn, where, enrich=enrich)
    if res:
      ShowStats(res)