Package rdkit :: Package ML :: Package DecTree :: Module TreeVis
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.DecTree.TreeVis

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2002,2003  Greg Landrum and Rational Discovery LLC 
  4  #    All Rights Reserved 
  5  # 
  6  """ functionality for drawing trees on sping canvases 
  7   
  8  """ 
  9  import math 
 10   
 11  from rdkit.sping import pid as piddle 
 12   
 13   
class VisOpts(object):
    """Container for the drawing parameters used by the tree-drawing routines.

    The values are plain class attributes; the module-level ``visOpts``
    instance created below is what the drawing functions actually read.
    Note that DrawTree() may assign smaller ``circRad`` / ``horizOffset``
    values on that shared instance when it shrinks a tree to fit a canvas.
    """
    # --- node geometry -------------------------------------------------
    # radius of every node that is not drawn as a scaled leaf
    circRad = 10
    # smallest and largest radius used for scaled leaves
    minCircRad = 4
    maxCircRad = 16
    # horizontal padding added to circRad to get the width reserved per leaf
    horizOffset = 10
    # vertical distance between a node and its children
    vertOffset = 50

    # --- node colors ---------------------------------------------------
    # fill for non-terminal nodes
    circColor = piddle.Color(0.6, 0.6, 0.9)
    # terminal fills are a blend between the "off" and "on" colors
    terminalOffColor = piddle.Color(0.2, 0.2, 0.2)
    terminalOnColor = piddle.Color(0.8, 0.8, 0.8)
    # outline for terminal nodes that hold no examples
    terminalEmptyColor = piddle.Color(0.8, 0.8, 0.2)
    # outline for every other node
    outlineColor = piddle.transparent

    # --- connecting lines and labels ------------------------------------
    lineColor = piddle.Color(0, 0, 0)
    # also used as the outline width of the node ellipses
    lineWidth = 2
    labelFont = piddle.Font(face='helvetica', size=10)

    # --- highlighting (not referenced by the functions in this listing) --
    highlightColor = piddle.Color(1.0, 1.0, 0.4)
    highlightWidth = 2


# shared options instance consulted by all drawing functions
visOpts = VisOpts()
def CalcTreeNodeSizes(node):
    """Recursively annotate a tree with size information.

    For _node_ and every node underneath it the following attributes are set:

      - nExamples: number of entries returned by the node's GetExamples()
      - totNChildren: number of leaves below the node (1 for a leaf itself)
      - nLevelsBelow: height of the subtree rooted at the node (1 for a leaf)

    Nothing is returned; the results live on the nodes.
    """
    kids = node.GetChildren()
    if len(kids) == 0:
        # a leaf counts as a single slot with nothing underneath it
        leafCount, deepest = 1, 0
    else:
        leafCount = deepest = 0
        for kid in kids:
            # children must be sized first so their attributes can be read
            CalcTreeNodeSizes(kid)
            leafCount += kid.totNChildren
            deepest = max(deepest, kid.nLevelsBelow)

    node.nExamples = len(node.GetExamples())
    node.totNChildren = leafCount
    node.nLevelsBelow = deepest + 1
57 58
59 -def _ExampleCounter(node, min, max):
60 if node.GetTerminal(): 61 cnt = node.nExamples 62 if cnt < min: 63 min = cnt 64 if cnt > max: 65 max = cnt 66 else: 67 for child in node.GetChildren(): 68 provMin, provMax = _ExampleCounter(child, min, max) 69 if provMin < min: 70 min = provMin 71 if provMax > max: 72 max = provMax 73 return min, max
74 75
76 -def _ApplyNodeScales(node, min, max):
77 if node.GetTerminal(): 78 if max != min: 79 loc = float(node.nExamples - min) / (max - min) 80 else: 81 loc = .5 82 node._scaleLoc = loc 83 else: 84 for child in node.GetChildren(): 85 _ApplyNodeScales(child, min, max)
86 87
def SetNodeScales(node):
    """Compute and store the leaf scaling information for a tree.

    The range of example counts over all terminal nodes is stored on
    _node_ as the 2-tuple _scales, then each terminal node receives its
    relative position in that range (see _ApplyNodeScales).
    """
    # the search starts from sentinels that any real count will replace
    fewest, most = _ExampleCounter(node, 1e8, -1e8)
    node._scales = (fewest, most)
    _ApplyNodeScales(node, fewest, most)
93 94
def DrawTreeNode(node, loc, canvas, nRes=2, scaleLeaves=False, showPurity=False):
    """Recursively displays the given tree node and all its children on the canvas

    **Arguments**

      - node: the node to draw.  It must provide GetChildren(), GetTerminal()
        and GetLabel(); GetExamples() is used when sizes have to be
        calculated and when _showPurity_ is set.

      - loc: (x, y) position of the center of this node's circle

      - canvas: the canvas to draw on (drawLine, drawEllipse, drawString,
        stringWidth and fontHeight are used)

      - nRes: number of possible result codes.  The fill of a terminal node
        is blended using label / (nRes - 1), so this must be larger than 1.

      - scaleLeaves: if set, the radius of a terminal node is interpolated
        between visOpts.minCircRad and visOpts.maxCircRad using the node's
        _scaleLoc attribute (0.5 is assumed when the attribute is missing)

      - showPurity: if set, a white line is drawn from the center of each
        terminal node to its rim; its angle encodes the fraction of the
        node's examples whose last element matches the node's label.

    **Notes**

      - if the node has not been sized yet (totNChildren missing or None),
        CalcTreeNodeSizes() is called on it first

      - the bounding box of the drawn circle is stored in node._bBox

    """
    # totNChildren is absent on a fresh tree and None after ResetTree()
    if getattr(node, 'totNChildren', None) is None:
        CalcTreeNodeSizes(node)

    isTerminal = node.GetTerminal()

    if scaleLeaves and isTerminal:
        # fall back to the middle of the range for leaves that were never scaled
        scaleLoc = getattr(node, '_scaleLoc', 0.5)
        rad = visOpts.minCircRad + scaleLoc * (visOpts.maxCircRad - visOpts.minCircRad)
    else:
        rad = visOpts.circRad

    x1 = loc[0] - rad
    y1 = loc[1] - rad
    x2 = loc[0] + rad
    y2 = loc[1] + rad

    if showPurity and isTerminal:
        examples = node.GetExamples()
        nEx = len(examples)
        if nEx:
            tgtVal = int(node.GetLabel())
            # a single division avoids the rounding drift of summing 1/n repeatedly
            nMatch = sum(1 for ex in examples if int(ex[-1]) == tgtVal)
            purity = nMatch / float(nEx)
        else:
            purity = 1.0

        # the indicator swings from (0, +rad) at purity 0 to (0, -rad) at purity 1
        deg = purity * math.pi
        pureX = loc[0] + rad * math.sin(deg)
        pureY = loc[1] + rad * math.cos(deg)

    # just move down one level
    childY = loc[1] + visOpts.vertOffset
    # horizontal space reserved for each leaf
    slotWidth = visOpts.horizOffset + visOpts.circRad
    # this is the left-hand side of the leftmost span
    childX = loc[0] - (slotWidth * node.totNChildren) / 2
    for child in node.GetChildren():
        # center on this child's space
        halfWidth = (slotWidth * child.totNChildren) / 2
        childX = childX + halfWidth
        childLoc = [childX, childY]
        # the connector is drawn first so the child's circle covers its end
        canvas.drawLine(loc[0], loc[1], childLoc[0], childLoc[1], visOpts.lineColor,
                        visOpts.lineWidth)
        DrawTreeNode(child, childLoc, canvas, nRes=nRes, scaleLeaves=scaleLeaves,
                     showPurity=showPurity)

        # and move over to the leftmost point of the next child
        childX = childX + halfWidth

    if isTerminal:
        cFac = float(node.GetLabel()) / float(nRes - 1)
        theColor = (1. - cFac) * visOpts.terminalOffColor + cFac * visOpts.terminalOnColor
        # leaves without examples get a distinguishing outline
        if hasattr(node, 'GetExamples') and node.GetExamples():
            outlColor = visOpts.outlineColor
        else:
            outlColor = visOpts.terminalEmptyColor
        canvas.drawEllipse(x1, y1, x2, y2, outlColor, visOpts.lineWidth, theColor)
        if showPurity:
            canvas.drawLine(loc[0], loc[1], pureX, pureY, piddle.Color(1, 1, 1), 2)
    else:
        theColor = visOpts.circColor
        canvas.drawEllipse(x1, y1, x2, y2, visOpts.outlineColor, visOpts.lineWidth, theColor)

        # this does not need to be done every time
        canvas.defaultFont = visOpts.labelFont

        labelStr = str(node.GetLabel())
        strLoc = (loc[0] - canvas.stringWidth(labelStr) / 2, loc[1] + canvas.fontHeight() / 4)

        canvas.drawString(labelStr, strLoc[0], strLoc[1])
    node._bBox = (x1, y1, x2, y2)
177 178
def CalcTreeWidth(tree):
    """Return the horizontal space needed to draw _tree_ with the current options.

    The width is the number of leaves (tree.totNChildren) multiplied by the
    space reserved per leaf, visOpts.circRad + visOpts.horizOffset.

    If the tree has not been sized yet, CalcTreeNodeSizes() is called first.
    "Not sized" covers both a missing totNChildren attribute and a value of
    None, which is what ResetTree() leaves behind.
    """
    if getattr(tree, 'totNChildren', None) is None:
        CalcTreeNodeSizes(tree)
    return tree.totNChildren * (visOpts.circRad + visOpts.horizOffset)
186 187
def DrawTree(tree, canvas, nRes=2, scaleLeaves=False, allowShrink=True, showPurity=False):
    """Draw a complete tree on a canvas.

    **Arguments**

      - tree: the root node of the tree to draw

      - canvas: the canvas to draw on; its _size_ attribute is used to
        center the root horizontally and to decide whether to shrink

      - nRes: number of possible result codes (passed to DrawTreeNode)

      - scaleLeaves: if set, leaf scales are (re)computed with
        SetNodeScales() and leaves are drawn with a radius that reflects
        their number of examples

      - allowShrink: if set, visOpts.circRad and visOpts.horizOffset are
        halved until the tree fits the canvas width

      - showPurity: passed to DrawTreeNode

    **Notes**

      - shrinking assigns the smaller values on the shared module-level
        visOpts instance and they are not restored here, so later drawings
        use them as well

    """
    dims = canvas.size
    loc = (dims[0] / 2, visOpts.vertOffset)

    # SetNodeScales() reads node.nExamples, which CalcTreeNodeSizes() assigns,
    # so make sure the tree has been sized before scaling its leaves.
    if getattr(tree, 'totNChildren', None) is None:
        CalcTreeNodeSizes(tree)

    if scaleLeaves:
        SetNodeScales(tree)

    if allowShrink:
        treeWid = CalcTreeWidth(tree)
        while treeWid > dims[0]:
            visOpts.circRad /= 2
            visOpts.horizOffset /= 2
            treeWid = CalcTreeWidth(tree)

    DrawTreeNode(tree, loc, canvas, nRes, scaleLeaves=scaleLeaves, showPurity=showPurity)
205 206
def ResetTree(tree):
    """Clear the cached layout information on _tree_ and all of its descendants.

    On every node, _scales and totNChildren are set to None.  Other
    attributes are left untouched.
    """
    # explicit stack; children are pushed reversed so nodes are visited in
    # the same parent-first, left-to-right order as a recursive walk
    pending = [tree]
    while pending:
        current = pending.pop()
        current._scales = None
        current.totNChildren = None
        pending.extend(reversed(list(current.GetChildren())))
212 213
def _simpleTest(canv):
    """Build a small five-node tree and draw it on _canv_.

    The tree is a root with one internal child (holding two terminal
    children labelled 0 and 1) and one terminal child labelled 1.
    """
    from .Tree import TreeNode as Node
    root = Node(None, 'r', label='r')
    inner = root.AddChild('l1_1', label='l1_1')
    root.AddChild('l1_2', isTerminal=1, label=1)
    for leafName, leafLabel in (('l2_1', 0), ('l2_2', 1)):
        inner.AddChild(leafName, isTerminal=1, label=leafLabel)

    rootLoc = (150, visOpts.vertOffset)
    DrawTreeNode(root, rootLoc, canv)
if __name__ == '__main__':
    # demo: render the sample tree from _simpleTest() into test.png
    from rdkit.sping.PIL.pidPIL import PILCanvas
    canvas = PILCanvas(size=(300, 300), name='test.png')
    _simpleTest(canvas)
    canvas.save()