1
2
3
4
5
6 """ functionality to allow adjusting composite model contents
7
8 """
9 from __future__ import print_function
10
11 import copy
12
13 import numpy
14
15
def _AccumulateVotes(model, examples, weight, quantize, scores):
    """ adds _weight_ to _scores[i]_ for every example on which member
    model _i_ of the composite _model_ voted for the correct answer.

    **Arguments**

      - model: the composite; _ClassifyExample()_ is called on it for each
        example before its per-member votes are read with
        _GetVoteDetails()_

      - examples: sequence of data points whose last entry is the true
        answer (quantized with _model.QuantizeActivity()_ when _quantize_
        is true)

      - weight: amount added to a member's score per correct vote

      - quantize: whether the answer must be quantized before comparison

      - scores: numpy array with one entry per member model; updated in place

    """
    nModels = len(scores)
    for pt in examples:
        # called for its side effect: the votes read by GetVoteDetails()
        # below presumably come from this call (its return value is unused)
        model.ClassifyExample(pt)
        if quantize:
            ans = model.QuantizeActivity(pt)[-1]
        else:
            ans = pt[-1]
        votes = model.GetVoteDetails()
        for i in range(nModels):
            if votes[i] == ans:
                scores[i] += weight


def BalanceComposite(model, set1, set2, weight, targetSize, names1=None, names2=None):
    """ adjusts the contents of the composite model so as to maximize
    the weighted classification accuracy across the two data sets.

    Each member model is scored by its weighted accuracy,
    ``(1 - weight) * accuracy(set1) + weight * accuracy(set2)``, and the
    _targetSize_ best-scoring members are kept.

    **Arguments**

      - model: the composite model to adjust; when _names1_/_names2_ are
        given, _SetInputOrder()_ is called on this object itself

      - set1, set2: sequences of data points whose last entry is the true
        answer (activity)

      - weight: weight given to _set2_ in the score (_set1_ gets
        _1 - weight_)

      - targetSize: number of member models to keep; if the composite has
        fewer members than this, all of them are kept

      - names1, names2: (optional) input orderings passed to
        _model.SetInputOrder()_ before scoring _set1_ / _set2_

    **Returns**

      a shallow copy of _model_ whose _modelList_, _errList_, _countList_
      and _quantizationRequirements_ hold only the selected members, best
      first. Each member's error is _1 - score_.

    **Notes**:

      - if _names1_ and _names2_ are not provided, _set1_ and _set2_ should
        have the same ordering of columns and _model_ should already
        have had _SetInputOrder()_ called.

      - if one of the two sets is empty, members are scored by their plain
        accuracy on the other set; if both are empty, every score is zero.

    """
    S1 = len(set1)
    S2 = len(set2)
    nPts = S1 + S2
    if S1 and S2:
        # scale the per-example weights so that a perfect model accumulates
        # exactly nPts; after dividing by nPts below the score equals
        # (1 - weight) * accuracy1 + weight * accuracy2
        weight1 = float(nPts) * (1 - weight) / S1
        weight2 = float(nPts) * weight / S2
    else:
        # the proportional weights are undefined for an empty set: fall back
        # to plain accuracy over whichever set has data
        weight1 = weight2 = 1.0

    # start with a copy so that all the other composite attributes carry over
    res = copy.copy(model)
    res.modelList = []
    res.errList = []
    res.countList = []
    res.quantizationRequirements = []

    # builtin float: the numpy.float alias was removed in NumPy 1.24
    scores = numpy.zeros(len(model), float)
    quantize = bool(model.GetActivityQuantBounds())
    if names1 is not None:
        model.SetInputOrder(names1)
    _AccumulateVotes(model, set1, weight1, quantize, scores)
    if names2 is not None:
        model.SetInputOrder(names2)
    _AccumulateVotes(model, set2, weight2, quantize, scores)

    # normalize the scores (skipped when there is no data at all)
    if nPts:
        scores /= nPts

    # member indices ordered from best to worst score
    bestOrder = list(numpy.argsort(scores))
    bestOrder.reverse()
    chosen = bestOrder[:targetSize]
    print('\tTAKE:', chosen)

    # and now take the best set; the four lists stay parallel
    for idx in chosen:
        res.modelList.append(model.modelList[idx])
        res.errList.append(1. - scores[idx])
        res.countList.append(1)
        # FIX: this should probably be more general
        res.quantizationRequirements.append(0)
    return res
91