1
2
3
4
5 from __future__ import print_function
6
7 import numpy
8
9 from rdkit import Chem, Geometry
10 from rdkit import RDLogger
11 from rdkit.Chem.Subshape import SubshapeObjects
12 from rdkit.Numerics import Alignment
13
14 logger = RDLogger.logger()
15
16
25
26
28 for i in range(len(pts)):
29 if orderedTraversal:
30 jStart = i + 1
31 else:
32 jStart = 0
33 for j in range(jStart, len(pts)):
34 if j == i:
35 continue
36 if orderedTraversal:
37 kStart = j + 1
38 else:
39 kStart = 0
40 for k in range(j + 1, len(pts)):
41 if k == i or k == j:
42 continue
43 yield (i, j, k)
44
45
49
50
63
64
67 """ clusters a set of alignments and returns the cluster centroid """
68 from rdkit.ML.Cluster import Butina
69 dists = []
70 for i in range(len(alignments)):
71 TransformMol(mol, alignments[i].transform, newConfId=tempConfId)
72 shapeI = builder.GenerateSubshapeShape(mol, tempConfId, addSkeleton=False)
73 for j in range(i):
74 TransformMol(mol, alignments[j].transform, newConfId=tempConfId + 1)
75 shapeJ = builder.GenerateSubshapeShape(mol, tempConfId + 1, addSkeleton=False)
76 d = GetShapeShapeDistance(shapeI, shapeJ, distMetric)
77 dists.append(d)
78 mol.RemoveConformer(tempConfId + 1)
79 mol.RemoveConformer(tempConfId)
80 clusts = Butina.ClusterData(dists, len(alignments), neighborTol, isDistData=True)
81 res = [alignments[x[0]] for x in clusts]
82 return res
83
84
98
99
101 triangleRMSTol = 1.0
102 distMetric = SubshapeDistanceMetric.PROTRUDE
103 shapeDistTol = 0.2
104 numFeatThresh = 3
105 dirThresh = 2.6
106 edgeTol = 6.0
107
108
109 coarseGridToleranceMult = 1.0
110 medGridToleranceMult = 1.0
111
113 """ this is a generator function returning the possible triangle
114 matches between the two shapes
115 """
116 ssdTol = (self.triangleRMSTol**2) * 9
117 tgtPts = target.skelPts
118 queryPts = query.skelPts
119 tgtLs = {}
120 for i in range(len(tgtPts)):
121 for j in range(i + 1, len(tgtPts)):
122 l2 = (tgtPts[i].location - tgtPts[j].location).LengthSq()
123 tgtLs[(i, j)] = l2
124 queryLs = {}
125 for i in range(len(queryPts)):
126 for j in range(i + 1, len(queryPts)):
127 l2 = (queryPts[i].location - queryPts[j].location).LengthSq()
128 queryLs[(i, j)] = l2
129 compatEdges = {}
130 tol2 = self.edgeTol * self.edgeTol
131 for tk, tv in tgtLs.items():
132 for qk, qv in queryLs.items():
133 if abs(tv - qv) < tol2:
134 compatEdges[(tk, qk)] = 1
135 seqNo = 0
136 for tgtTri in _getAllTriangles(tgtPts, orderedTraversal=True):
137 tgtLocs = [tgtPts[x].location for x in tgtTri]
138 for queryTri in _getAllTriangles(queryPts, orderedTraversal=False):
139 if ((tgtTri[0], tgtTri[1]), (queryTri[0], queryTri[1])) in compatEdges and \
140 ((tgtTri[0], tgtTri[2]), (queryTri[0], queryTri[2])) in compatEdges and \
141 ((tgtTri[1], tgtTri[2]), (queryTri[1], queryTri[2])) in compatEdges:
142 queryLocs = [queryPts[x].location for x in queryTri]
143 ssd, tf = Alignment.GetAlignmentTransform(tgtLocs, queryLocs)
144 if ssd <= ssdTol:
145 alg = SubshapeAlignment()
146 alg.transform = tf
147 alg.triangleSSD = ssd
148 alg.targetTri = tgtTri
149 alg.queryTri = queryTri
150 alg._seqNo = seqNo
151 seqNo += 1
152 yield alg
153
155 nMatched = 0
156 for i in range(3):
157 tgtFeats = targetPts[alignment.targetTri[i]].molFeatures
158 qFeats = queryPts[alignment.queryTri[i]].molFeatures
159 if not tgtFeats and not qFeats:
160 nMatched += 1
161 else:
162 for jFeat in tgtFeats:
163 if jFeat in qFeats:
164 nMatched += 1
165 break
166 if nMatched >= self.numFeatThresh:
167 break
168 return nMatched >= self.numFeatThresh
169
171 i = 0
172 targetPts = target.skelPts
173 queryPts = query.skelPts
174 while i < len(alignments):
175 alg = alignments[i]
176 if not self._checkMatchFeatures(targetPts, queryPts, alg):
177 if pruneStats is not None:
178 pruneStats['features'] = pruneStats.get('features', 0) + 1
179 del alignments[i]
180 else:
181 i += 1
182
184 dot = 0.0
185 for i in range(3):
186 tgtPt = targetPts[alignment.targetTri[i]]
187 queryPt = queryPts[alignment.queryTri[i]]
188 qv = queryPt.shapeDirs[0]
189 tv = tgtPt.shapeDirs[0]
190 rotV = [0.0] * 3
191 rotV[0] = alignment.transform[0, 0] * qv[0] + alignment.transform[0, 1] * qv[
192 1] + alignment.transform[0, 2] * qv[2]
193 rotV[1] = alignment.transform[1, 0] * qv[0] + alignment.transform[1, 1] * qv[
194 1] + alignment.transform[1, 2] * qv[2]
195 rotV[2] = alignment.transform[2, 0] * qv[0] + alignment.transform[2, 1] * qv[
196 1] + alignment.transform[2, 2] * qv[2]
197 dot += abs(rotV[0] * tv[0] + rotV[1] * tv[1] + rotV[2] * tv[2])
198 if dot >= self.dirThresh:
199
200 break
201 alignment.dirMatch = dot
202 return dot >= self.dirThresh
203
205 i = 0
206 tgtPts = target.skelPts
207 queryPts = query.skelPts
208 while i < len(alignments):
209 if not self._checkMatchDirections(tgtPts, queryPts, alignments[i]):
210 if pruneStats is not None:
211 pruneStats['direction'] = pruneStats.get('direction', 0) + 1
212 del alignments[i]
213 else:
214 i += 1
215
227
def _checkMatchShape(self, targetMol, target, queryMol, query, alignment, builder, targetConf,
                     queryConf, pruneStats=None, tConfId=1001):
    """ checks whether an alignment superimposes the query shape on the target

    The query molecule (conformer ``queryConf``) is moved by
    ``alignment.transform`` into a temporary conformer ``tConfId``. Its shape
    is then compared with the target's at three grid resolutions, coarsest
    first, so that bad alignments are rejected as cheaply as possible:

      1. spacing ``2 * builder.gridSpacing`` vs. ``target.coarseGrid``,
         tolerance ``shapeDistTol * coarseGridToleranceMult``
      2. spacing ``1.5 * builder.gridSpacing`` vs. ``target.medGrid``,
         tolerance ``shapeDistTol * medGridToleranceMult``
      3. spacing ``builder.gridSpacing`` vs. ``target`` itself,
         tolerance ``shapeDistTol``

    Distances are computed with ``self.distMetric``. The first stage whose
    distance exceeds its tolerance rejects the alignment. If ``pruneStats``
    is a dict, the matching ``'coarseGrid'``, ``'medGrid'`` or ``'fineGrid'``
    counter is incremented. When the fine stage is reached, its distance is
    stored on ``alignment.shapeDist``.

    ``target`` must already carry ``coarseGrid`` and ``medGrid`` attributes.
    ``targetMol``, ``query`` and ``targetConf`` are not used here.

    Once the temporary conformer has been created, it is always removed from
    ``queryMol``, and ``builder.gridSpacing`` is always restored. This holds
    even if shape generation raises, so a reused builder is never left at a
    modified spacing.

    Returns True if the alignment passes all three stages, False otherwise.
    """
    TransformMol(queryMol, alignment.transform, confId=queryConf, newConfId=tConfId)
    oSpace = builder.gridSpacing
    try:
        # stage 1: coarse grid (twice the normal spacing)
        builder.gridSpacing = oSpace * 2
        coarseGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
        d = GetShapeShapeDistance(coarseGrid, target.coarseGrid, self.distMetric)
        if d > self.shapeDistTol * self.coarseGridToleranceMult:
            if pruneStats is not None:
                pruneStats['coarseGrid'] = pruneStats.get('coarseGrid', 0) + 1
            return False

        # stage 2: medium grid (1.5x the normal spacing)
        builder.gridSpacing = oSpace * 1.5
        medGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
        d = GetShapeShapeDistance(medGrid, target.medGrid, self.distMetric)
        if d > self.shapeDistTol * self.medGridToleranceMult:
            if pruneStats is not None:
                pruneStats['medGrid'] = pruneStats.get('medGrid', 0) + 1
            return False

        # stage 3: fine grid at the builder's own spacing, against the full target shape
        builder.gridSpacing = oSpace
        fineGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
        d = GetShapeShapeDistance(fineGrid, target, self.distMetric)
        alignment.shapeDist = d
        if d > self.shapeDistTol:
            if pruneStats is not None:
                pruneStats['fineGrid'] = pruneStats.get('fineGrid', 0) + 1
            return False
        return True
    finally:
        # executed on every return path and when an exception propagates
        queryMol.RemoveConformer(tConfId)
        builder.gridSpacing = oSpace
260
def PruneMatchesUsingShape(self, targetMol, target, queryMol, query, builder, alignments,
                           tgtConf=-1, queryConf=-1, pruneStats=None):
    """ removes, in place, the alignments that fail the shape comparison

    Every entry of ``alignments`` is passed to ``_checkMatchShape``, which may
    update ``pruneStats``. Entries that fail are dropped. If ``target`` has no
    ``medGrid`` attribute, its coarse and medium grids are generated first
    (from conformer ``tgtConf``). Progress is logged every 100 alignments.

    ``alignments`` must be a list; it is modified in place and nothing is
    returned.
    """
    if not hasattr(target, 'medGrid'):
        self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

    logger.info("Shape-based Pruning")
    nOrig = len(alignments)
    kept = []
    # Filter in a single pass. Deleting from the list while walking it costs
    # O(n) per deletion, i.e. O(n**2) overall.
    for nDone, alg in enumerate(alignments, 1):
        if not nDone % 100:
            # alignments still in play: survivors so far plus those not yet
            # checked (including the current one)
            nLeft = len(kept) + nOrig - nDone + 1
            logger.info(' processed %d of %d. %d alignments remain' % (nDone, nOrig, nLeft))
        if self._checkMatchShape(targetMol, target, queryMol, query, alg, builder,
                                 targetConf=tgtConf, queryConf=queryConf, pruneStats=pruneStats):
            kept.append(alg)
    # slice assignment updates the caller's list object instead of rebinding the name
    alignments[:] = kept
281
def GetSubshapeAlignments(self, targetMol, target, queryMol, query, builder, tgtConf=-1,
                          queryConf=-1, pruneStats=None):
    """ returns the alignments of query onto target that survive all pruning steps

    Candidate alignments from ``GetTriangleMatches`` are filtered in turn by:
      - feature matching, only when ``builder.featFactory`` is set;
      - direction matching;
      - shape comparison.
    Each stage is timed and logged.

    If ``pruneStats`` is a dict, it receives the rejection counters written by
    the pruning methods. It also receives the elapsed wall-clock times, in
    seconds:
      - ``'gtm_time'``
      - ``'feats_time'`` (only when the feature stage runs)
      - ``'direction_time'``
      - ``'shape_time'``
    When ``pruneStats`` is None, a private dict is used and the statistics are
    discarded.

    Returns a list of the surviving alignments.
    """
    import time
    if pruneStats is None:
        pruneStats = {}

    logger.info("Generating triangle matches")
    t1 = time.time()
    res = list(self.GetTriangleMatches(target, query))
    t2 = time.time()
    logger.info("Got %d possible alignments in %.1f seconds" % (len(res), t2 - t1))
    pruneStats['gtm_time'] = t2 - t1

    if builder.featFactory:
        logger.info("Doing feature pruning")
        t1 = time.time()
        self.PruneMatchesUsingFeatures(target, query, res, pruneStats=pruneStats)
        t2 = time.time()
        pruneStats['feats_time'] = t2 - t1
        logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))

    logger.info("Doing direction pruning")
    t1 = time.time()
    self.PruneMatchesUsingDirection(target, query, res, pruneStats=pruneStats)
    t2 = time.time()
    pruneStats['direction_time'] = t2 - t1
    logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))

    t1 = time.time()
    self.PruneMatchesUsingShape(targetMol, target, queryMol, query, builder, res, tgtConf=tgtConf,
                                queryConf=queryConf, pruneStats=pruneStats)
    t2 = time.time()
    pruneStats['shape_time'] = t2 - t1
    # report the outcome of the shape stage the same way as the other stages
    logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))
    return res
312
def __call__(self, targetMol, target, queryMol, query, builder, tgtConf=-1, queryConf=-1,
             pruneStats=None):
    """ lazily yields the alignments of query onto target that pass every filter

    This is the streaming counterpart of ``GetSubshapeAlignments``. Each
    candidate from ``GetTriangleMatches`` is checked in turn for:
      - feature agreement, only when ``builder.featFactory`` is set;
      - direction agreement;
      - shape distance.
    A candidate is yielded as soon as it passes all of these checks.

    If ``pruneStats`` is a dict, rejections are counted under ``'features'``
    and ``'direction'`` here, and by ``_checkMatchShape`` for the shape stage.

    The target's coarse and medium grids are generated on first need. That is,
    they are built when a candidate first reaches the shape stage, and only if
    ``target`` has no ``medGrid`` attribute yet.
    """
    for alignment in self.GetTriangleMatches(target, query):
        # Test the flag first, so the feature comparison is skipped entirely
        # when feature pruning is disabled.
        if builder.featFactory and not self._checkMatchFeatures(target.skelPts, query.skelPts,
                                                                alignment):
            if pruneStats is not None:
                pruneStats['features'] = pruneStats.get('features', 0) + 1
            continue
        if not self._checkMatchDirections(target.skelPts, query.skelPts, alignment):
            if pruneStats is not None:
                pruneStats['direction'] = pruneStats.get('direction', 0) + 1
            continue

        # build the target's coarse/medium grids lazily, the first time they are needed
        if not hasattr(target, 'medGrid'):
            self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

        if not self._checkMatchShape(targetMol, target, queryMol, query, alignment, builder,
                                     targetConf=tgtConf, queryConf=queryConf,
                                     pruneStats=pruneStats):
            continue

        yield alignment
334
335
if __name__ == '__main__':
    # Development harness. It loads pickled (mol, shape) fixtures and a shape
    # builder from the current directory, aligns the query onto the target,
    # prints the number of surviving alignments, and displays both molecules
    # and their shapes in PyMol via MolViewer.
    # The stdlib pickle is used instead of the rdkit.six cPickle shim, which
    # only existed for Python 2/3 compatibility.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files here.
    import pickle
    from rdkit.Chem.PyMol import MolViewer
    with open('target.pkl', 'rb') as f:
        tgtMol, tgtShape = pickle.load(f)
    with open('query.pkl', 'rb') as f:
        queryMol, queryShape = pickle.load(f)
    with open('builder.pkl', 'rb') as f:
        builder = pickle.load(f)
    aligner = SubshapeAligner()
    algs = aligner.GetSubshapeAlignments(tgtMol, tgtShape, queryMol, queryShape, builder)
    print(len(algs))

    v = MolViewer()
    v.ShowMol(tgtMol, name='Target', showOnly=True)
    v.ShowMol(queryMol, name='Query', showOnly=False)
    SubshapeObjects.DisplaySubshape(v, tgtShape, 'target_shape', color=(.8, .2, .2))
    SubshapeObjects.DisplaySubshape(v, queryShape, 'query_shape', color=(.2, .2, .8))
354