Package rdkit :: Package Chem :: Package Subshape :: Module SubshapeAligner
[hide private]
[frames | no frames]

Source Code for Module rdkit.Chem.Subshape.SubshapeAligner

  1  # 
  2  # Copyright (C) 2007-2008 by Greg Landrum 
  3  #  All rights reserved 
  4  # 
  5  from __future__ import print_function 
  6   
  7  import numpy 
  8   
  9  from rdkit import Chem, Geometry 
 10  from rdkit import RDLogger 
 11  from rdkit.Chem.Subshape import SubshapeObjects 
 12  from rdkit.Numerics import Alignment 
 13   
 14  logger = RDLogger.logger() 
 15   
 16   
class SubshapeAlignment(object):
    """Record describing one candidate alignment of a query shape onto a target.

    Every field is a class-level default; the aligner overwrites the fields on
    each instance as the alignment is generated and scored.
    """
    # transform: homogeneous transform (TransformMol applies it to [x, y, z, 1]).
    # triangleSSD: superposition SSD of the matched triangles.
    # targetTri / queryTri: index triples of the matched skeleton points.
    transform = triangleSSD = targetTri = queryTri = None
    # conformer id of an aligned copy; -1 means unset (not assigned in this module)
    alignedConfId = -1
    # dirMatch: summed |dot| of matched shape directions.
    # shapeDist: fine-grid shape distance.
    dirMatch = shapeDist = 0.0
25 26
27 -def _getAllTriangles(pts, orderedTraversal=False):
28 for i in range(len(pts)): 29 if orderedTraversal: 30 jStart = i + 1 31 else: 32 jStart = 0 33 for j in range(jStart, len(pts)): 34 if j == i: 35 continue 36 if orderedTraversal: 37 kStart = j + 1 38 else: 39 kStart = 0 40 for k in range(j + 1, len(pts)): 41 if k == i or k == j: 42 continue 43 yield (i, j, k)
44 45
class SubshapeDistanceMetric(object):
    """Integer codes selecting the metric used by GetShapeShapeDistance."""
    TANIMOTO, PROTRUDE = 0, 1
49 50
def GetShapeShapeDistance(s1, s2, distMetric):
    """ returns the distance between two shapes according to the provided metric

    Any metric other than SubshapeDistanceMetric.PROTRUDE yields the Tanimoto
    distance of the two grids.  For the protrude distance, the grid with the
    smaller total occupancy is passed as the first argument.
    """
    if distMetric != SubshapeDistanceMetric.PROTRUDE:
        return Geometry.TanimotoDistance(s1.grid, s2.grid)
    occ1 = s1.grid.GetOccupancyVect().GetTotalVal()
    occ2 = s2.grid.GetOccupancyVect().GetTotalVal()
    if occ1 < occ2:
        return Geometry.ProtrudeDistance(s1.grid, s2.grid)
    return Geometry.ProtrudeDistance(s2.grid, s1.grid)
63 64
def ClusterAlignments(mol, alignments, builder, neighborTol=0.1,
                      distMetric=SubshapeDistanceMetric.PROTRUDE, tempConfId=1001):
    """ clusters a set of alignments and returns the cluster centroids

    Each alignment's transform is applied to ``mol`` (in a temporary conformer
    with id ``tempConfId``) and a shape is generated from the result.  The
    pairwise shape distances are clustered with the Butina algorithm using
    ``neighborTol`` as the distance threshold.

    Returns a list containing, for each cluster, the alignment that Butina
    reports first for that cluster (its centroid).  The temporary conformer
    is removed from ``mol`` before returning.
    """
    from rdkit.ML.Cluster import Butina
    # Build each alignment's shape once.  Regenerating shape j inside the
    # pairwise loop would cost O(n^2) shape builds.  The trade-off is keeping
    # n shapes in memory at once.
    shapes = []
    for alignment in alignments:
        TransformMol(mol, alignment.transform, newConfId=tempConfId)
        try:
            shapes.append(builder.GenerateSubshapeShape(mol, tempConfId, addSkeleton=False))
        finally:
            mol.RemoveConformer(tempConfId)
    # lower triangle of the distance matrix, row by row, as Butina expects
    # when isDistData=True
    dists = [GetShapeShapeDistance(shapes[i], shapes[j], distMetric)
             for i in range(len(shapes)) for j in range(i)]
    clusts = Butina.ClusterData(dists, len(alignments), neighborTol, isDistData=True)
    return [alignments[clust[0]] for clust in clusts]
83 84
def TransformMol(mol, tform, confId=-1, newConfId=100):
    """ Applies the transformation to a molecule's conformer and stores the result

    Each atom position of conformer ``confId`` is extended to [x, y, z, 1],
    multiplied by ``tform``, and the first three components are kept.  The
    transformed coordinates go into a new conformer with id ``newConfId``.
    That conformer replaces any existing conformer with the same id; the
    molecule's other conformers are untouched.
    """
    outConf = Chem.Conformer()
    outConf.SetId(0)
    srcConf = mol.GetConformer(confId)
    for atomIdx in range(srcConf.GetNumAtoms()):
        homog = numpy.array(list(srcConf.GetAtomPosition(atomIdx)) + [1.0])
        moved = numpy.dot(tform, homog)
        outConf.SetAtomPosition(atomIdx, list(moved)[:3])
    outConf.SetId(newConfId)
    mol.RemoveConformer(newConfId)
    mol.AddConformer(outConf, assignId=False)
98 99
class SubshapeAligner(object):
    """Generates and filters alignments of a query subshape onto a target subshape.

    Candidate alignments come from pairing triangles of skeleton points
    (GetTriangleMatches).  They are then pruned in three stages:
    feature agreement at the triangle vertices, agreement of the skeleton
    points' shape directions, and shape distance evaluated on coarse, medium
    and fine grids.
    """
    # triangle superposition cutoff: an SSD up to triangleRMSTol**2 * 9 is accepted
    triangleRMSTol = 1.0
    # metric handed to GetShapeShapeDistance during shape pruning
    distMetric = SubshapeDistanceMetric.PROTRUDE
    # maximum shape distance accepted on the fine grid; the coarse and medium
    # grid cutoffs are this value times the multipliers below
    shapeDistTol = 0.2
    # number of triangle vertex pairs whose features must agree
    numFeatThresh = 3
    # minimum summed |dot product| of the matched shape directions
    dirThresh = 2.6
    # squared edge lengths must differ by less than edgeTol**2 to be compatible
    edgeTol = 6.0
    coarseGridToleranceMult = 1.0
    medGridToleranceMult = 1.0

    @staticmethod
    def _countPruned(pruneStats, key):
        """Increment pruneStats[key] when a statistics dict was supplied."""
        if pruneStats is not None:
            pruneStats[key] = pruneStats.get(key, 0) + 1

    @staticmethod
    def _squaredEdgeLengths(pts):
        """Map each index pair (i, j), i < j, to the squared distance between the points."""
        nPts = len(pts)
        return {(i, j): (pts[i].location - pts[j].location).LengthSq()
                for i in range(nPts) for j in range(i + 1, nPts)}

    def GetTriangleMatches(self, target, query):
        """ this is a generator function returning the possible triangle
        matches between the two shapes

        Each target triangle (indices increasing) is paired with each query
        triangle from an unordered traversal.  A pair is a candidate when all
        three corresponding edges have compatible lengths.  Candidates whose
        superposition SSD is within tolerance are yielded as
        SubshapeAlignment objects, numbered in ``_seqNo``.
        """
        ssdTol = (self.triangleRMSTol**2) * 9
        tgtPts = target.skelPts
        queryPts = query.skelPts
        tgtLs = self._squaredEdgeLengths(tgtPts)
        queryLs = self._squaredEdgeLengths(queryPts)
        tol2 = self.edgeTol * self.edgeTol
        # (target edge, query edge) pairs whose squared lengths agree; each
        # edge is keyed by its sorted index pair
        compatEdges = set()
        for tk, tv in tgtLs.items():
            for qk, qv in queryLs.items():
                if abs(tv - qv) < tol2:
                    compatEdges.add((tk, qk))

        def edgeKey(a, b):
            return (a, b) if a < b else (b, a)

        def compatible(tgtTri, queryTri, a, b):
            # Query triangles are not index-sorted, so the edge key must be
            # normalized.  An unsorted key would never be found, and triangles
            # whose vertices are not in increasing order could never match.
            return ((edgeKey(tgtTri[a], tgtTri[b]), edgeKey(queryTri[a], queryTri[b]))
                    in compatEdges)

        seqNo = 0
        for tgtTri in _getAllTriangles(tgtPts, orderedTraversal=True):
            tgtLocs = [tgtPts[x].location for x in tgtTri]
            for queryTri in _getAllTriangles(queryPts, orderedTraversal=False):
                if not (compatible(tgtTri, queryTri, 0, 1) and
                        compatible(tgtTri, queryTri, 0, 2) and
                        compatible(tgtTri, queryTri, 1, 2)):
                    continue
                queryLocs = [queryPts[x].location for x in queryTri]
                ssd, tf = Alignment.GetAlignmentTransform(tgtLocs, queryLocs)
                if ssd <= ssdTol:
                    alg = SubshapeAlignment()
                    alg.transform = tf
                    alg.triangleSSD = ssd
                    alg.targetTri = tgtTri
                    alg.queryTri = queryTri
                    alg._seqNo = seqNo
                    seqNo += 1
                    yield alg

    def _checkMatchFeatures(self, targetPts, queryPts, alignment):
        """Return True when at least numFeatThresh matched vertex pairs agree on features.

        A vertex pair agrees when neither point has features, or when they
        share at least one feature.
        """
        nMatched = 0
        for tIdx, qIdx in zip(alignment.targetTri, alignment.queryTri):
            tgtFeats = targetPts[tIdx].molFeatures
            qFeats = queryPts[qIdx].molFeatures
            if (not tgtFeats and not qFeats) or any(f in qFeats for f in tgtFeats):
                nMatched += 1
                if nMatched >= self.numFeatThresh:
                    break
        return nMatched >= self.numFeatThresh

    def PruneMatchesUsingFeatures(self, target, query, alignments, pruneStats=None):
        """Remove, in place, the alignments that fail _checkMatchFeatures.

        Each rejection increments pruneStats['features'] when pruneStats is given.
        """
        targetPts = target.skelPts
        queryPts = query.skelPts
        kept = []
        for alg in alignments:
            if self._checkMatchFeatures(targetPts, queryPts, alg):
                kept.append(alg)
            else:
                self._countPruned(pruneStats, 'features')
        # slice assignment keeps the caller's list object (callers rely on in-place pruning)
        alignments[:] = kept

    def _checkMatchDirections(self, targetPts, queryPts, alignment):
        """Score how well the matched points' first shape directions agree.

        Each query direction is multiplied by the upper-left 3x3 block of the
        alignment transform (translation ignored).  The absolute dot products
        with the target directions are summed.  The sum is stored in
        alignment.dirMatch and compared against dirThresh.  Summation stops
        early once the threshold is reached.
        """
        tform = alignment.transform
        dot = 0.0
        for tIdx, qIdx in zip(alignment.targetTri, alignment.queryTri):
            tv = targetPts[tIdx].shapeDirs[0]
            qv = queryPts[qIdx].shapeDirs[0]
            rotV = [tform[r, 0] * qv[0] + tform[r, 1] * qv[1] + tform[r, 2] * qv[2]
                    for r in range(3)]
            dot += abs(rotV[0] * tv[0] + rotV[1] * tv[1] + rotV[2] * tv[2])
            if dot >= self.dirThresh:
                # already above the threshold, no need to continue
                break
        alignment.dirMatch = dot
        return dot >= self.dirThresh

    def PruneMatchesUsingDirection(self, target, query, alignments, pruneStats=None):
        """Remove, in place, the alignments that fail _checkMatchDirections.

        Each rejection increments pruneStats['direction'] when pruneStats is given.
        """
        tgtPts = target.skelPts
        queryPts = query.skelPts
        kept = []
        for alg in alignments:
            if self._checkMatchDirections(tgtPts, queryPts, alg):
                kept.append(alg)
            else:
                self._countPruned(pruneStats, 'direction')
        alignments[:] = kept

    def _addCoarseAndMediumGrids(self, mol, tgt, confId, builder):
        """Attach medGrid (1.5x spacing) and coarseGrid (2x spacing) shapes to tgt.

        With a molecule, the shapes are generated from conformer confId.
        Otherwise they are resampled from tgt itself.  builder.gridSpacing is
        restored even if generation fails.
        """
        oSpace = builder.gridSpacing
        if mol:
            try:
                builder.gridSpacing = oSpace * 1.5
                tgt.medGrid = builder.GenerateSubshapeShape(mol, confId, addSkeleton=False)
                builder.gridSpacing = oSpace * 2
                tgt.coarseGrid = builder.GenerateSubshapeShape(mol, confId, addSkeleton=False)
            finally:
                builder.gridSpacing = oSpace
        else:
            tgt.medGrid = builder.SampleSubshape(tgt, oSpace * 1.5)
            tgt.coarseGrid = builder.SampleSubshape(tgt, oSpace * 2.0)

    def _checkMatchShape(self, targetMol, target, queryMol, query, alignment, builder, targetConf,
                         queryConf, pruneStats=None, tConfId=1001):
        """Return True if the transformed query passes the coarse, medium and fine shape tests.

        The query is placed in temporary conformer tConfId.  Its shape is
        compared with the target on progressively finer grids, stopping at
        the first failure.  That failure is counted in pruneStats under
        'coarseGrid', 'medGrid' or 'fineGrid'.  When the fine stage is
        reached, its distance is stored in alignment.shapeDist.  The
        temporary conformer and builder.gridSpacing are restored even if a
        step raises.
        """
        matchOk = True
        TransformMol(queryMol, alignment.transform, confId=queryConf, newConfId=tConfId)
        oSpace = builder.gridSpacing
        try:
            builder.gridSpacing = oSpace * 2
            coarseGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
            d = GetShapeShapeDistance(coarseGrid, target.coarseGrid, self.distMetric)
            if d > self.shapeDistTol * self.coarseGridToleranceMult:
                matchOk = False
                self._countPruned(pruneStats, 'coarseGrid')
            else:
                builder.gridSpacing = oSpace * 1.5
                medGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
                d = GetShapeShapeDistance(medGrid, target.medGrid, self.distMetric)
                if d > self.shapeDistTol * self.medGridToleranceMult:
                    matchOk = False
                    self._countPruned(pruneStats, 'medGrid')
                else:
                    builder.gridSpacing = oSpace
                    fineGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
                    d = GetShapeShapeDistance(fineGrid, target, self.distMetric)
                    if d > self.shapeDistTol:
                        matchOk = False
                        self._countPruned(pruneStats, 'fineGrid')
                    alignment.shapeDist = d
        finally:
            queryMol.RemoveConformer(tConfId)
            builder.gridSpacing = oSpace
        return matchOk

    def PruneMatchesUsingShape(self, targetMol, target, queryMol, query, builder, alignments,
                               tgtConf=-1, queryConf=-1, pruneStats=None):
        """Remove, in place, the alignments that fail _checkMatchShape.

        The target's coarse/medium grids are created on first use.  Progress
        is logged every 100 alignments.
        """
        if not hasattr(target, 'medGrid'):
            self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

        logger.info("Shape-based Pruning")
        nOrig = len(alignments)
        kept = []
        for nDone, alg in enumerate(alignments, 1):
            if not nDone % 100:
                # kept so far plus those not yet examined (including this one)
                nLeft = len(kept) + nOrig - nDone + 1
                logger.info(' processed %d of %d. %d alignments remain' % ((nDone, nOrig, nLeft)))
            if self._checkMatchShape(targetMol, target, queryMol, query, alg, builder,
                                     targetConf=tgtConf, queryConf=queryConf,
                                     pruneStats=pruneStats):
                kept.append(alg)
        alignments[:] = kept

    def GetSubshapeAlignments(self, targetMol, target, queryMol, query, builder, tgtConf=-1,
                              queryConf=-1, pruneStats=None):
        """Return the list of alignments of query onto target that survive every pruning step.

        Feature pruning runs only when builder.featFactory is set.  Stage
        timings are stored in pruneStats under 'gtm_time', 'feats_time',
        'direction_time' and 'shape_time', alongside the rejection counts.
        """
        import time
        if pruneStats is None:
            pruneStats = {}
        logger.info("Generating triangle matches")
        t1 = time.time()
        res = list(self.GetTriangleMatches(target, query))
        t2 = time.time()
        logger.info("Got %d possible alignments in %.1f seconds" % (len(res), t2 - t1))
        pruneStats['gtm_time'] = t2 - t1
        if builder.featFactory:
            logger.info("Doing feature pruning")
            t1 = time.time()
            self.PruneMatchesUsingFeatures(target, query, res, pruneStats=pruneStats)
            t2 = time.time()
            pruneStats['feats_time'] = t2 - t1
            logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))
        logger.info("Doing direction pruning")
        t1 = time.time()
        self.PruneMatchesUsingDirection(target, query, res, pruneStats=pruneStats)
        t2 = time.time()
        pruneStats['direction_time'] = t2 - t1
        logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))
        t1 = time.time()
        self.PruneMatchesUsingShape(targetMol, target, queryMol, query, builder, res, tgtConf=tgtConf,
                                    queryConf=queryConf, pruneStats=pruneStats)
        t2 = time.time()
        pruneStats['shape_time'] = t2 - t1
        return res

    def __call__(self, targetMol, target, queryMol, query, builder, tgtConf=-1, queryConf=-1,
                 pruneStats=None):
        """Lazily yield the alignments that pass every filter, one at a time.

        Applies the same feature (only when builder.featFactory is set),
        direction and shape checks as GetSubshapeAlignments, per candidate.
        """
        targetPts = target.skelPts
        queryPts = query.skelPts
        for alignment in self.GetTriangleMatches(target, query):
            # test the cheap flag first so the feature check is skipped when unused
            if builder.featFactory and not self._checkMatchFeatures(targetPts, queryPts, alignment):
                self._countPruned(pruneStats, 'features')
                continue
            if not self._checkMatchDirections(targetPts, queryPts, alignment):
                self._countPruned(pruneStats, 'direction')
                continue

            if not hasattr(target, 'medGrid'):
                self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

            if not self._checkMatchShape(targetMol, target, queryMol, query, alignment, builder,
                                         targetConf=tgtConf, queryConf=queryConf,
                                         pruneStats=pruneStats):
                continue
            # if we made it this far, it's a good alignment
            yield alignment
if __name__ == '__main__':  # pragma: nocover
    # rdkit.six has been removed from RDKit; the stdlib pickle module works on
    # both Python 2 and 3.
    import pickle
    from rdkit.Chem.PyMol import MolViewer
    # NOTE(security): pickle.load can execute arbitrary code; only run this
    # demo on .pkl files from a trusted source.
    with open('target.pkl', 'rb') as f:
        tgtMol, tgtShape = pickle.load(f)
    with open('query.pkl', 'rb') as f:
        queryMol, queryShape = pickle.load(f)
    with open('builder.pkl', 'rb') as f:
        builder = pickle.load(f)
    aligner = SubshapeAligner()
    algs = aligner.GetSubshapeAlignments(tgtMol, tgtShape, queryMol, queryShape, builder)
    print(len(algs))

    v = MolViewer()
    v.ShowMol(tgtMol, name='Target', showOnly=True)
    v.ShowMol(queryMol, name='Query', showOnly=False)
    SubshapeObjects.DisplaySubshape(v, tgtShape, 'target_shape', color=(.8, .2, .2))
    SubshapeObjects.DisplaySubshape(v, queryShape, 'query_shape', color=(.2, .2, .8))