1
2
3
4
5
6 """ Implementation of the clustering algorithm published in:
7 Butina JCICS 39 747-750 (1999)
8
9 """
10 import numpy
11 from rdkit import RDLogger
12 logger = RDLogger.logger()
13
14
16 dv = numpy.array(pi) - numpy.array(pj)
17 return numpy.sqrt(dv * dv)
18
19
21 """ clusters the data points passed in and returns the list of clusters
22
23 **Arguments**
24
25 - data: a list of items with the input data
26 (see discussion of _isDistData_ argument for the exception)
27
28 - nPts: the number of points to be used
29
30 - distThresh: elements within this range of each other are considered
31 to be neighbors
32
33 - isDistData: set this toggle when the data passed in is a
34 distance matrix. The distance matrix should be stored
35 symmetrically. An example of how to do this:
36
37 dists = []
38 for i in range(nPts):
39 for j in range(i):
40 dists.append( distfunc(i,j) )
41
42 - distFunc: a function to calculate distances between points.
43 Receives 2 points as arguments, should return a float
44
45 - reordering: if this toggle is set, the number of neighbors is updated
46 for the unassigned molecules after a new cluster is created such
47 that always the molecule with the largest number of unassigned
48 neighbors is selected as the next cluster center.
49
50 **Returns**
51
52 - a tuple of tuples containing information about the clusters:
53 ( (cluster1_elem1, cluster1_elem2, ...),
54 (cluster2_elem1, cluster2_elem2, ...),
55 ...
56 )
57 The first element for each cluster is its centroid.
58
59 """
60 if isDistData and len(data) > (nPts * (nPts - 1) / 2):
61 logger.warning("Distance matrix is too long")
62 nbrLists = [None] * nPts
63 for i in range(nPts):
64 nbrLists[i] = []
65
66 dmIdx = 0
67 for i in range(nPts):
68 for j in range(i):
69 if not isDistData:
70 dij = distFunc(data[i], data[j])
71 else:
72 dij = data[dmIdx]
73 dmIdx += 1
74 if dij <= distThresh:
75 nbrLists[i].append(j)
76 nbrLists[j].append(i)
77
78 tLists = [(len(y), x) for x, y in enumerate(nbrLists)]
79 tLists.sort(reverse=True)
80
81 res = []
82 seen = [0] * nPts
83 while tLists:
84 _, idx = tLists.pop(0)
85 if seen[idx]:
86 continue
87 tRes = [idx]
88 for nbr in nbrLists[idx]:
89 if not seen[nbr]:
90 tRes.append(nbr)
91 seen[nbr] = 1
92
93
94
95 if reordering:
96
97
98
99 nbrNbr = [nbrLists[t] for t in tRes]
100 nbrNbr = frozenset().union(*nbrNbr)
101
102
103 for x, y in enumerate(tLists):
104 y1 = y[1]
105 if seen[y1] or (y1 not in nbrNbr):
106 continue
107
108 nbrLists[y1] = set(nbrLists[y1]).difference(tRes)
109 tLists[x] = (len(nbrLists[y1]), y1)
110
111 tLists.sort(reverse=True)
112 res.append(tuple(tRes))
113 return tuple(res)
114