1
2
3
4
5
6 """ code for dealing with composite models
7
8 For a model to be useable here, it should support the following API:
9
10 - _ClassifyExample(example)_, returns a classification
11
12 Other compatibility notes:
13
14 1) To use _Composite.Grow_ there must be some kind of builder
15 functionality which returns a 2-tuple containing (model,percent accuracy).
16
17 2) The models should be pickleable
18
19 3) It is very helpful if the models support comparison (the __cmp__ method) so that
20 the membership tests used to make sure models are unique work.
21
22
23
24 """
25 from __future__ import print_function
26 import numpy
27 from rdkit.six.moves import cPickle
28 from rdkit.ML.Data import DataUtils
29
30
32 """a composite model
33
34
35 **Notes**
36
37 - adding a model which is already present just results in its count
38 field being incremented and the errors being averaged.
39
40 - typical usage:
41
42 1) grow the composite with AddModel until happy with it
43
44 2) call AverageErrors to calculate the average error values
45
46 3) call SortModels to put things in order by either error or count
47
48 - Composites can support individual models requiring either quantized or
49 nonquantized data. This is done by keeping a set of quantization bounds
50 (_QuantBounds_) in the composite and quantizing data passed in when required.
51 Quantization bounds can be set and interrogated using the
52 _Get/SetQuantBounds()_ methods. When models are added to the composite,
53 it can be indicated whether or not they require quantization.
54
55 - Composites are also capable of extracting relevant variables from longer lists.
56 This is accessible using _SetDescriptorNames()_ to register the descriptors about
57 which the composite cares and _SetInputOrder()_ to tell the composite what the
58 ordering of input vectors will be. **Note** there is a limitation on this: each
59 model needs to take the same set of descriptors as inputs. This could be changed.
60
61 """
62
64 self.modelList = []
65 self.errList = []
66 self.countList = []
67 self.modelVotes = []
68 self.quantBounds = None
69 self.nPossibleVals = None
70 self.quantizationRequirements = []
71 self._descNames = []
72 self._mapOrder = None
73 self.activityQuant = []
74
76 self._modelFilterFrac = modelFilterFrac
77 self._modelFilterVal = modelFilterVal
78
80 """ registers the names of the descriptors this composite uses
81
82 **Arguments**
83
84 - names: a list of descriptor names (strings).
85
86 **NOTE**
87
88 the _names_ list is not
89 copied, so if you modify it later, the composite itself will also be modified.
90
91 """
92 self._descNames = names
93
95 """ returns the names of the descriptors this composite uses
96
97 """
98 return self._descNames
99
101 """ sets the quantization bounds that the composite will use
102
103 **Arguments**
104
105 - qBounds: a list of quantization bounds, each quantbound is a
106 list of boundaries
107
108 - nPossible: a list of integers indicating how many possible values
109 each descriptor can take on.
110
111 **NOTE**
112
113 - if the two lists are of different lengths, this will assert out
114
115 - neither list is copied, so if you modify it later, the composite
116 itself will also be modified.
117
118 """
119 if nPossible is not None:
120 assert len(qBounds) == len(nPossible), 'qBounds/nPossible mismatch'
121 self.quantBounds = qBounds
122 self.nPossibleVals = nPossible
123
125 """ returns the quantization bounds
126
127 **Returns**
128
129 a 2-tuple consisting of:
130
131 1) the list of quantization bounds
132
133 2) the nPossibleVals list
134
135 """
136 return self.quantBounds, self.nPossibleVals
137
139 if not hasattr(self, 'activityQuant'):
140 self.activityQuant = []
141 return self.activityQuant
142
144 self.activityQuant = bounds
145
147 if activityQuant is None:
148 activityQuant = self.activityQuant
149 if activityQuant:
150 example = example[:]
151 act = example[actCol]
152 for box in range(len(activityQuant)):
153 if act < activityQuant[box]:
154 act = box
155 break
156 else:
157 act = box + 1
158 example[actCol] = act
159 return example
160
162 """ quantizes an example
163
164 **Arguments**
165
166 - example: a data point (list, tuple or numpy array)
167
168 - quantBounds: a list of quantization bounds, each quantbound is a
169 list of boundaries. If this argument is not provided, the composite
170 will use its own quantBounds
171
172 **Returns**
173
174 the quantized example as a list
175
176 **Notes**
177
178 - If _example_ is different in length from _quantBounds_, this will
179 assert out.
180
181 - This is primarily intended for internal use
182
183 """
184 if quantBounds is None:
185 quantBounds = self.quantBounds
186 assert len(example) == len(quantBounds), 'example/quantBounds mismatch'
187 quantExample = [None] * len(example)
188 for i in range(len(quantBounds)):
189 bounds = quantBounds[i]
190 p = example[i]
191 if len(bounds):
192 for box in range(len(bounds)):
193 if p < bounds[box]:
194 p = box
195 break
196 else:
197 p = box + 1
198 else:
199 if i != 0:
200 p = int(p)
201 quantExample[i] = p
202 return quantExample
203
205 """ creates a histogram of error/count pairs
206
207 **Returns**
208
209 the histogram as a series of (error, count) 2-tuples
210
211 """
212 nExamples = len(self.modelList)
213 histo = []
214 i = 1
215 lastErr = self.errList[0]
216 countHere = self.countList[0]
217 eps = 0.001
218 while i < nExamples:
219 if self.errList[i] - lastErr > eps:
220 histo.append((lastErr, countHere))
221 lastErr = self.errList[i]
222 countHere = self.countList[i]
223 else:
224 countHere = countHere + self.countList[i]
225 i = i + 1
226
227 return histo
228
def CollectVotes(self, example, quantExample, appendExample=0, onlyModels=None):
    """ gathers one vote per composite member for the given example

    **Arguments**

      - example: the example to be voted upon

      - quantExample: the quantized form of the example; handed to the
        models flagged in _quantizationRequirements_

      - appendExample: forwarded to each model's _ClassifyExample()_ as
        its _appendExamples_ keyword

      - onlyModels: if provided (and non-empty), a sequence of model
        indices. Only those models are asked to vote.

    **Returns**

      a list with one entry per model in the composite: the model's
      prediction rounded to an int, or -1 for models that were not asked

    """
    nModels = len(self)
    # an empty/None selection means "use every model"
    selected = onlyModels if onlyModels else list(range(nModels))

    votes = [-1] * nModels
    for idx in selected:
        # each model sees either the quantized or the raw form of the example
        data = quantExample if self.quantizationRequirements[idx] else example
        prediction = self.modelList[idx].ClassifyExample(data, appendExamples=appendExample)
        votes[idx] = int(round(prediction))

    return votes
262
def ClassifyExample(self, example, threshold=0, appendExample=0, onlyModels=None):
    """ classifies the given example using the entire composite

    **Arguments**

      - example: the data to be classified

      - threshold: a classification is only returned if the confidence is
        above _threshold_. Anything lower is returned as -1.

      - appendExample: toggles saving the example on the models

      - onlyModels: if provided (and non-empty), this should be a sequence
        of model indices. Only the specified models will be used in the
        prediction.

    **Returns**

      a (result,confidence) tuple. If no votes at all were cast (for
      example because the composite contains no models) the result is
      (-1, 0.0).

    **Notes**

      - This is treated as a voting problem: each selected model's vote is
        weighted by its count and the confidence is the fraction of the
        total weight that went to the winning result.

      - the per-model votes are stored on the composite (_modelVotes_) as
        a side effect

    """
    if self._mapOrder is not None:
        example = self._RemapInput(example)
    if self.GetActivityQuantBounds():
        example = self.QuantizeActivity(example)
    # only pay for quantization if at least one model needs quantized data
    if self.quantBounds is not None and 1 in self.quantizationRequirements:
        quantExample = self.QuantizeExample(example, self.quantBounds)
    else:
        quantExample = []

    if not onlyModels:
        onlyModels = list(range(len(self)))
    self.modelVotes = self.CollectVotes(example, quantExample, appendExample=appendExample,
                                        onlyModels=onlyModels)

    # tally the votes, one bin per possible activity value
    votes = [0] * self.nPossibleVals[-1]
    for i in onlyModels:
        res = self.modelVotes[i]
        votes[res] += self.countList[i]

    totVotes = sum(votes)
    if not totVotes:
        # nothing voted: there is no winner and the confidence would be 0/0
        return -1, 0.0

    res = numpy.argmax(votes)
    conf = float(votes[res]) / float(totVotes)
    if conf > threshold:
        return res, conf
    return -1, conf
320
322 """ returns the votes from the last classification
323
324 This will be _None_ if nothing has yet been classified
325 """
326 return self.modelVotes
327
359
365
407
def Grow(self, examples, attrs, nPossibleVals, buildDriver, pruner=None, nTries=10, pruneIt=0,
         needsQuantization=1, progressCallback=None, **buildArgs):
    """ Grows the composite

    **Arguments**

      - examples: a list of examples to be used in training

      - attrs: a list of the variables to be used in training

      - nPossibleVals: this is used to provide a list of the number
        of possible values for each variable. It is used if the
        local quantBounds have not been set (for example for when you
        are working with data which is already quantized).

      - buildDriver: the function to call to build the new models; it is
        called as _buildDriver(trainSet, attrs, nPossibleVals, **buildArgs)_
        and must return a (model, error) 2-tuple

      - pruner: a function used to "prune" (reduce the complexity of)
        the resulting model.

      - nTries: the number of new models to add

      - pruneIt: toggles whether or not pruning is done

      - needsQuantization: used to indicate whether or not this type of model
        requires quantized data

      - progressCallback: if provided, called with the cycle index after
        each model has been added

      - **buildArgs: all other keyword args are passed to _buildDriver_.
        The _silent_ keyword is consumed here to control the progress
        printout; _buildDriver_ always receives silent=1 and
        calcTotalError=1.

    **Note**

      - new models are *added* to the existing ones

      - if activity quantization bounds are set, the last entry of the
        _nPossibleVals_ argument is overwritten in place and, unless the
        input had to be remapped first, the entries of _examples_ are
        replaced in place with their activity-quantized forms

    """
    silent = buildArgs.get('silent', 0)
    buildArgs['silent'] = 1
    buildArgs['calcTotalError'] = 1

    if self._mapOrder is not None:
        # must be a real list: it is measured and indexed below, and a bare
        # map() is a one-shot iterator under Python 3
        examples = [self._RemapInput(ex) for ex in examples]
    if self.GetActivityQuantBounds():
        for i in range(len(examples)):
            examples[i] = self.QuantizeActivity(examples[i])
        nPossibleVals[-1] = len(self.GetActivityQuantBounds()) + 1
    if self.nPossibleVals is None:
        self.nPossibleVals = nPossibleVals[:]
    if needsQuantization:
        nPossibleVals = self.nPossibleVals
        trainExamples = [self.QuantizeExample(ex, self.quantBounds) for ex in examples]
    else:
        trainExamples = examples

    # older pickled composites may not carry the filter attributes at all
    useFilter = hasattr(self, '_modelFilterFrac') and self._modelFilterFrac != 0

    for i in range(nTries):
        if useFilter:
            trainIdx, _ = DataUtils.FilterData(trainExamples, self._modelFilterVal,
                                               self._modelFilterFrac, -1, indicesOnly=1)
            trainSet = [trainExamples[x] for x in trainIdx]
        else:
            trainSet = trainExamples

        model, frac = buildDriver(trainSet, attrs, nPossibleVals, **buildArgs)
        if pruneIt:
            model, frac = pruner(model, model.GetTrainingExamples(), model.GetTestExamples(),
                                 minimizeTestErrorOnly=0)
        if useFilter and hasattr(model, '_trainIndices'):
            # the model's indices refer to the filtered set; translate them
            # back to positions in the full training list
            model._trainIndices = [trainIdx[x] for x in model._trainIndices]

        self.AddModel(model, frac, needsQuantization)
        # report roughly ten times per run; floor division keeps the modulus
        # an integer regardless of Python version
        if not silent and (nTries < 10 or i % (nTries // 10) == 0):
            print('Cycle: % 4d' % (i))
        if progressCallback is not None:
            progressCallback(i)
490
492 for i in range(len(self)):
493 m = self.GetModel(i)
494 try:
495 m.ClearExamples()
496 except AttributeError:
497 pass
498
def Pickle(self, fileName='foo.pkl', saveExamples=0):
    """ Writes this composite off to a file so that it can be easily loaded later

    **Arguments**

      - fileName: the name of the file to be written

      - saveExamples: if this is zero, the individual models will have
        their stored examples cleared before the composite is written.

    **Note**

      - clearing the examples modifies the composite itself, not just the
        copy that ends up in the file

    """
    if not saveExamples:
        self.ClearModelExamples()

    # the context manager guarantees the file is closed even if dump() raises
    with open(fileName, 'wb+') as pFile:
        cPickle.dump(self, pFile, 1)
516
def AddModel(self, model, error, needsQuantization=1):
    """ Adds a model to the composite

    **Arguments**

      - model: the model to be added

      - error: the model's error

      - needsQuantization: a toggle to indicate whether or not this model
        requires quantized inputs

    **NOTE**

      - this can be used as an alternative to _Grow()_ if you already have
        some models constructed

      - if the model is already present, its count is incremented and
        _error_ is added to its stored error; _needsQuantization_ is
        ignored in that case

      - the errList is run as an accumulator,
        you probably want to call _AverageErrors_ after finishing the forest

    """
    try:
        # a single scan both detects a duplicate and locates it
        idx = self.modelList.index(model)
    except ValueError:
        # first time we have seen this model: register it
        self.modelList.append(model)
        self.errList.append(error)
        self.countList.append(1)
        self.quantizationRequirements.append(needsQuantization)
    else:
        self.errList[idx] += error
        self.countList[idx] += 1
555
557 """ convert local summed error to average error
558
559 """
560 self.errList = list(map(lambda x, y: x / y, self.errList, self.countList))
561
563 """ sorts the list of models
564
565 **Arguments**
566
567 sortOnError: toggles sorting on the models' errors rather than their counts
568
569
570 """
571 if sortOnError:
572 order = numpy.argsort(self.errList)
573 else:
574 order = numpy.argsort(self.countList)
575
576
577
578
579 self.modelList = [self.modelList[x] for x in order]
580 self.countList = [self.countList[x] for x in order]
581 self.errList = [self.errList[x] for x in order]
582
584 """ returns a particular model
585
586 """
587 return self.modelList[i]
588
590 """ replaces a particular model
591
592 **Note**
593
594 This is included for the sake of completeness, but you need to be
595 *very* careful when you use it.
596
597 """
598 self.modelList[i] = val
599
601 """ returns the count of the _i_th model
602
603 """
604 return self.countList[i]
605
607 """ sets the count of the _i_th model
608
609 """
610 self.countList[i] = val
611
613 """ returns the error of the _i_th model
614
615 """
616 return self.errList[i]
617
619 """ sets the error of the _i_th model
620
621 """
622 self.errList[i] = val
623
625 """ returns all relevant data about a particular model
626
627 **Arguments**
628
629 i: an integer indicating which model should be returned
630
631 **Returns**
632
633 a 3-tuple consisting of:
634
635 1) the model
636
637 2) its count
638
639 3) its error
640 """
641 return (self.modelList[i], self.countList[i], self.errList[i])
642
644 """ sets all relevant data for a particular tree in the forest
645
646 **Arguments**
647
648 - i: an integer indicating which model should be set
649
650 - tup: a 3-tuple consisting of:
651
652 1) the model
653
654 2) its count
655
656 3) its error
657
658 **Note**
659
660 This is included for the sake of completeness, but you need to be
661 *very* careful when you use it.
662
663 """
664 self.modelList[i], self.countList[i], self.errList[i] = tup
665
667 """ Returns everything we know
668
669 **Returns**
670
671 a 3-tuple consisting of:
672
673 1) our list of models
674
675 2) our list of model counts
676
677 3) our list of model errors
678
679 """
680 return (self.modelList, self.countList, self.errList)
681
683 """ allows len(composite) to work
684
685 """
686 return len(self.modelList)
687
689 """ allows composite[i] to work, returns the data tuple
690
691 """
692 return self.GetDataTuple(which)
693
695 """ returns a string representation of the composite
696
697 """
698 outStr = 'Composite\n'
699 for i in range(len(self.modelList)):
700 outStr = (outStr + ' Model %4d: %5d occurances %%%5.2f average error\n' %
701 (i, self.countList[i], 100. * self.errList[i]))
702 return outStr
703
704
if __name__ == '__main__':
    # Small manual demo of AddModel/AverageErrors/SortModels and
    # QuantizeExample. It is switched off; flip the condition to run it.
    if False:
        from rdkit.ML.DecTree import DecTree
        composite = Composite()
        node = DecTree.DecTreeNode(None, 'foo')
        # adding the same node twice exercises the duplicate-model path
        for _ in range(2):
            composite.AddModel(node, 0.5)
        composite.AverageErrors()
        composite.SortModels()
        print(composite)

        bounds = [[], [.5, 1, 1.5]]
        samples = [['foo', 0], ['foo', .4], ['foo', .6], ['foo', 1.1], ['foo', 2.0]]
        print('quantBounds:', bounds)
        for sample in samples:
            print(sample, composite.QuantizeExample(sample, bounds))
724