1
2
3
4
5
6 import RDLogger
7 logger = RDLogger.logger()
8 import Chem,Geometry
9 import numpy
10 from Numerics import Alignment
11 from Chem.Subshape import SubshapeObjects
12
21
23 for i in range(len(pts)):
24 if orderedTraversal:
25 jStart=i+1
26 else:
27 jStart=0
28 for j in range(jStart,len(pts)):
29 if j==i:
30 continue
31 if orderedTraversal:
32 kStart=j+1
33 else:
34 kStart=0
35 for k in range(j+1,len(pts)):
36 if k==i or k==j:
37 continue
38 yield (i,j,k)
39
43
44
56
57
62 from ML.Cluster import Butina
63 dists = []
64 for i in range(len(alignments)):
65 TransformMol(mol,alignments[i].transform,newConfId=tempConfId)
66 shapeI=builder.GenerateSubshapeShape(mol,tempConfId,addSkeleton=False)
67 for j in range(i):
68 TransformMol(mol,alignments[j].transform,newConfId=tempConfId+1)
69 shapeJ=builder.GenerateSubshapeShape(mol,tempConfId+1,addSkeleton=False)
70 d = GetShapeShapeDistance(shapeI,shapeJ,distMetric)
71 dists.append(d)
72 mol.RemoveConformer(tempConfId+1)
73 mol.RemoveConformer(tempConfId)
74 clusts=Butina.ClusterData(dists,len(alignments),neighborTol,isDistData=True)
75 res = [alignments[x[0]] for x in clusts]
76 return res
77
94
96 triangleRMSTol=1.0
97 distMetric=SubshapeDistanceMetric.PROTRUDE
98 shapeDistTol=0.2
99 numFeatThresh=3
100 dirThresh=2.6
101 edgeTol=6.0
102
103
104 coarseGridToleranceMult=1.0
105 medGridToleranceMult=1.0
106
108 """ this is a generator function returning the possible triangle
109 matches between the two shapes
110 """
111 ssdTol = (self.triangleRMSTol**2)*9
112 res = []
113 tgtPts = target.skelPts
114 queryPts = query.skelPts
115 tgtLs = {}
116 for i in range(len(tgtPts)):
117 for j in range(i+1,len(tgtPts)):
118 l2 = (tgtPts[i].location-tgtPts[j].location).LengthSq()
119 tgtLs[(i,j)]=l2
120 queryLs = {}
121 for i in range(len(queryPts)):
122 for j in range(i+1,len(queryPts)):
123 l2 = (queryPts[i].location-queryPts[j].location).LengthSq()
124 queryLs[(i,j)]=l2
125 compatEdges={}
126 tol2 = self.edgeTol*self.edgeTol
127 for tk,tv in tgtLs.iteritems():
128 for qk,qv in queryLs.iteritems():
129 if abs(tv-qv)<tol2:
130 compatEdges[(tk,qk)]=1
131 seqNo=0
132 for tgtTri in _getAllTriangles(tgtPts,orderedTraversal=True):
133 tgtLocs=[tgtPts[x].location for x in tgtTri]
134 for queryTri in _getAllTriangles(queryPts,orderedTraversal=False):
135 if compatEdges.has_key(((tgtTri[0],tgtTri[1]),(queryTri[0],queryTri[1]))) and \
136 compatEdges.has_key(((tgtTri[0],tgtTri[2]),(queryTri[0],queryTri[2]))) and \
137 compatEdges.has_key(((tgtTri[1],tgtTri[2]),(queryTri[1],queryTri[2]))):
138 queryLocs=[queryPts[x].location for x in queryTri]
139 ssd,tf = Alignment.GetAlignmentTransform(tgtLocs,queryLocs)
140 if ssd<=ssdTol:
141 alg = SubshapeAlignment()
142 alg.transform=tf
143 alg.triangleSSD=ssd
144 alg.targetTri=tgtTri
145 alg.queryTri=queryTri
146 alg._seqNo=seqNo
147 seqNo+=1
148 yield alg
149
151 nMatched=0
152 for i in range(3):
153 tgtFeats = targetPts[alignment.targetTri[i]].molFeatures
154 qFeats = queryPts[alignment.queryTri[i]].molFeatures
155 if not tgtFeats and not qFeats:
156 nMatched+=1
157 else:
158 for j,jFeat in enumerate(tgtFeats):
159 if jFeat in qFeats:
160 nMatched+=1
161 break
162 if nMatched>=self.numFeatThresh:
163 break
164 return nMatched>=self.numFeatThresh
165
167 i = 0
168 targetPts = target.skelPts
169 queryPts = query.skelPts
170 while i<len(alignments):
171 alg = alignments[i]
172 if not self._checkMatchFeatures(targetPts,queryPts,alg):
173 if pruneStats is not None:
174 pruneStats['features']=pruneStats.get('features',0)+1
175 del alignments[i]
176 else:
177 i+=1
178
196
198 i = 0
199 tgtPts = target.skelPts
200 queryPts = query.skelPts
201 while i<len(alignments):
202 if not self._checkMatchDirections(tgtPts,queryPts,alignments[i]):
203 if pruneStats is not None:
204 pruneStats['direction']=pruneStats.get('direction',0)+1
205 del alignments[i]
206 else:
207 i+=1
208
220
def _checkMatchShape(self, targetMol, target, queryMol, query, alignment, builder,
                     targetConf, queryConf, pruneStats=None, tConfId=1001):
    """Test one alignment against the target shape at up to three grid levels.

    The query molecule is passed through TransformMol() with
    ``alignment.transform`` (source conformer ``queryConf``, destination
    conformer id ``tConfId``).  A shape is then generated from conformer
    ``tConfId`` with the builder's grid spacing set to 2x, 1.5x and finally
    1x its current value, and compared with ``target.coarseGrid``,
    ``target.medGrid`` and ``target`` respectively.  A finer level is only
    evaluated when the coarser level was within tolerance.

    Arguments:
      - targetMol, query, targetConf: not used by this method.
      - target: the target shape; ``coarseGrid`` and ``medGrid`` attributes
        are read from it.
      - queryMol: the query molecule; conformer ``tConfId`` is removed from
        it before returning.
      - alignment: the candidate; its ``transform`` is read and its
        ``shapeDist`` is set to the last distance that was computed.
      - builder: the shape builder; its ``gridSpacing`` is changed
        temporarily and restored before returning.
      - queryConf: conformer id handed to TransformMol() as ``confId``.
      - pruneStats: optional dict; the 'coarseGrid', 'medGrid' or 'fineGrid'
        entry is incremented for the level that rejected the alignment.
      - tConfId: id used for the temporary transformed conformer.

    Returns:
      True if every evaluated level was within tolerance, False otherwise.

    The conformer removal and the grid-spacing reset run in a ``finally``
    clause, so they also happen when shape generation or the distance
    calculation raises.
    """
    def _countPruned(stage):
        # record which level rejected the alignment, if the caller wants stats
        if pruneStats is not None:
            pruneStats[stage] = pruneStats.get(stage, 0) + 1

    oSpace = builder.gridSpacing
    # presumably creates conformer tConfId on queryMol -- TODO confirm against
    # TransformMol(); this method removes that conformer again below.
    TransformMol(queryMol, alignment.transform, confId=queryConf, newConfId=tConfId)
    matchOk = True
    try:
        # coarse level: twice the normal grid spacing
        builder.gridSpacing = oSpace * 2
        coarseGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
        d = GetShapeShapeDistance(coarseGrid, target.coarseGrid, self.distMetric)
        if d > self.shapeDistTol * self.coarseGridToleranceMult:
            matchOk = False
            _countPruned('coarseGrid')
        else:
            # medium level: 1.5x the normal grid spacing
            builder.gridSpacing = oSpace * 1.5
            medGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
            d = GetShapeShapeDistance(medGrid, target.medGrid, self.distMetric)
            if d > self.shapeDistTol * self.medGridToleranceMult:
                matchOk = False
                _countPruned('medGrid')
            else:
                # fine level: the builder's own spacing, compared with the
                # target shape itself and the unscaled tolerance
                builder.gridSpacing = oSpace
                fineGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
                d = GetShapeShapeDistance(fineGrid, target, self.distMetric)
                if d > self.shapeDistTol:
                    matchOk = False
                    _countPruned('fineGrid')
        alignment.shapeDist = d
    finally:
        # always undo the temporary state, even on an exception
        queryMol.RemoveConformer(tConfId)
        builder.gridSpacing = oSpace
    return matchOk
254
def PruneMatchesUsingShape(self, targetMol, target, queryMol, query, builder,
                           alignments, tgtConf=-1, queryConf=-1,
                           pruneStats=None):
    """Remove, in place, every alignment that fails self._checkMatchShape().

    Arguments:
      - targetMol, target, queryMol, query, builder: forwarded unchanged to
        self._checkMatchShape().
      - alignments: mutable sequence of candidate alignments; on return it
        contains only the ones that passed, in their original order.
      - tgtConf: forwarded as ``targetConf`` to self._checkMatchShape() and
        to self._addCoarseAndMediumGrids().
      - queryConf: forwarded as ``queryConf`` to self._checkMatchShape().
      - pruneStats: optional dict forwarded to self._checkMatchShape().

    If ``target`` has no ``medGrid`` attribute,
    self._addCoarseAndMediumGrids() is called on it first.

    Survivors are collected in a single pass and written back with one slice
    assignment instead of deleting from the middle of the list on every
    rejection.  ``alignments`` is therefore only modified once all
    candidates have been checked.
    """
    if not hasattr(target, 'medGrid'):
        self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

    logger.info("Shape-based Pruning")
    nOrig = len(alignments)
    kept = []
    for idx, alg in enumerate(list(alignments)):
        nDone = idx + 1
        if not nDone % 100:
            # alignments not rejected so far: everything kept plus everything
            # not yet examined (the current one included)
            nLeft = nOrig - (idx - len(kept))
            logger.info(' processed %d of %d. %d alignments remain'%((nDone,
                                                                      nOrig,
                                                                      nLeft)))
        if self._checkMatchShape(targetMol, target, queryMol, query, alg, builder,
                                 targetConf=tgtConf, queryConf=queryConf,
                                 pruneStats=pruneStats):
            kept.append(alg)
    # slice assignment keeps the caller's list object and updates it in place
    alignments[:] = kept
280
def GetSubshapeAlignments(self, targetMol, target, queryMol, query, builder,
                          tgtConf=-1, queryConf=-1, pruneStats=None):
    """Return the list of alignments that survive every pruning stage.

    All triangle matches from self.GetTriangleMatches(target, query) are
    collected into a list, which is then filtered in place by
    PruneMatchesUsingFeatures() (only when ``builder.featFactory`` is
    truthy), PruneMatchesUsingDirection() and PruneMatchesUsingShape().

    Wall-clock timings of the stages are stored in ``pruneStats`` under
    'gtm_time', 'feats_time' (only when feature pruning runs),
    'direction_time' and 'shape_time'.  When ``pruneStats`` is None a
    throw-away dict is used.
    """
    import time
    now = time.time

    stats = pruneStats
    if stats is None:
        stats = {}

    logger.info("Generating triangle matches")
    started = now()
    candidates = list(self.GetTriangleMatches(target, query))
    elapsed = now() - started
    logger.info("Got %d possible alignments in %.1f seconds"%(len(candidates), elapsed))
    stats['gtm_time'] = elapsed

    if builder.featFactory:
        logger.info("Doing feature pruning")
        started = now()
        self.PruneMatchesUsingFeatures(target, query, candidates, pruneStats=stats)
        elapsed = now() - started
        stats['feats_time'] = elapsed
        logger.info("%d possible alignments remain. (%.1f seconds required)"%(len(candidates), elapsed))

    logger.info("Doing direction pruning")
    started = now()
    self.PruneMatchesUsingDirection(target, query, candidates, pruneStats=stats)
    elapsed = now() - started
    stats['direction_time'] = elapsed
    logger.info("%d possible alignments remain. (%.1f seconds required)"%(len(candidates), elapsed))

    started = now()
    self.PruneMatchesUsingShape(targetMol, target, queryMol, query, builder, candidates,
                                tgtConf=tgtConf, queryConf=queryConf,
                                pruneStats=stats)
    stats['shape_time'] = now() - started
    return candidates
312
def __call__(self, targetMol, target, queryMol, query, builder,
             tgtConf=-1, queryConf=-1, pruneStats=None):
    """Generator: lazily yield the alignments that pass all three filters.

    Each triangle match from self.GetTriangleMatches(target, query) is
    checked in turn with _checkMatchFeatures() (only when
    ``builder.featFactory`` is truthy), _checkMatchDirections() and
    _checkMatchShape(); a candidate is dropped at the first check it fails.

    When ``pruneStats`` is a dict, its 'features' or 'direction' entry is
    incremented for candidates rejected by those checks here; ``pruneStats``
    is also forwarded to _checkMatchShape().

    Before the shape check, self._addCoarseAndMediumGrids() is called on
    ``target`` if it has no ``medGrid`` attribute.
    """
    def _tally(reason):
        if pruneStats is not None:
            pruneStats[reason] = pruneStats.get(reason, 0) + 1

    for candidate in self.GetTriangleMatches(target, query):
        # feature filter is skipped entirely without a feature factory
        if builder.featFactory:
            featuresOk = self._checkMatchFeatures(target.skelPts, query.skelPts, candidate)
            if not featuresOk:
                _tally('features')
                continue

        directionsOk = self._checkMatchDirections(target.skelPts, query.skelPts, candidate)
        if not directionsOk:
            _tally('direction')
            continue

        if not hasattr(target, 'medGrid'):
            self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

        shapeOk = self._checkMatchShape(targetMol, target, queryMol, query, candidate, builder,
                                        targetConf=tgtConf, queryConf=queryConf,
                                        pruneStats=pruneStats)
        if shapeOk:
            yield candidate
335
336
if __name__=='__main__':
    # Demo driver: load a pickled target, query and builder from the current
    # directory, run the aligner, print how many alignments survive, and show
    # both molecules and their shapes in PyMol.
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    def _loadPickle(path):
        # the with-block closes the file even when unpickling fails
        with open(path, 'rb') as inF:
            return pickle.load(inF)

    # NOTE(review): unpickling can execute arbitrary code -- only run this
    # against pickle files from a trusted source.
    tgtMol, tgtShape = _loadPickle('target.pkl')
    queryMol, queryShape = _loadPickle('query.pkl')
    builder = _loadPickle('builder.pkl')

    aligner = SubshapeAligner()
    algs = aligner.GetSubshapeAlignments(tgtMol, tgtShape, queryMol, queryShape, builder)
    print(len(algs))

    from Chem.PyMol import MolViewer
    v = MolViewer()
    v.ShowMol(tgtMol, name='Target', showOnly=True)
    v.ShowMol(queryMol, name='Query', showOnly=False)
    SubshapeObjects.DisplaySubshape(v, tgtShape, 'target_shape', color=(.8,.2,.2))
    SubshapeObjects.DisplaySubshape(v, queryShape, 'query_shape', color=(.2,.2,.8))
352