1
2
3
""" Handles cross validation of k-nearest-neighbors models and the
evaluation of individual models

"""
9 from __future__ import print_function
10
11 from rdkit.ML.Data import SplitData
12 from rdkit.ML.KNN import DistFunctions
13 from rdkit.ML.KNN.KNNClassificationModel import KNNClassificationModel
14 from rdkit.ML.KNN.KNNRegressionModel import KNNRegressionModel
15
16
19
20
23
24
26 """
27 Determines the classification error for the testExamples
28
29 **Arguments**
30
31 - tree: a decision tree (or anything supporting a _ClassifyExample()_ method)
32
33 - testExamples: a list of examples to be used for testing
34
35 - appendExamples: a toggle which is passed along to the tree as it does
36 the classification. The trees can use this to store the examples they
37 classify locally.
38
39 **Returns**
40
41 a 2-tuple consisting of:
42 """
43 nTest = len(testExamples)
44
45 if isinstance(knnMod, KNNClassificationModel):
46 badExamples = []
47 nBad = 0
48 for i in range(nTest):
49 testEx = testExamples[i]
50 trueRes = testEx[-1]
51 res = knnMod.ClassifyExample(testEx, appendExamples)
52 if (trueRes != res):
53 badExamples.append(testEx)
54 nBad += 1
55 return float(nBad) / nTest, badExamples
56 elif isinstance(knnMod, KNNRegressionModel):
57 devSum = 0.0
58 for i in range(nTest):
59 testEx = testExamples[i]
60 trueRes = testEx[-1]
61 res = knnMod.PredictExample(testEx, appendExamples)
62 devSum += abs(trueRes - res)
63 return devSum / nTest, None
64 raise ValueError("Unrecognized Model Type")
65
66
""" Driver function for building a KNN model of a specified type and
estimating its error on a hold-out set

**Arguments**

  - examples: the full set of examples; they are split by index into a
    training set and a hold-out (test) set

  - numNeigh: number of neighbors for the KNN model (basically k in k-NN)

  - modelBuilder: a callable invoked as modelBuilder(numNeigh, attrs, distFunc);
    the object it returns is the model that is trained and evaluated, so it
    determines whether a classification or a regression model is built

  - attrs: passed through unchanged to modelBuilder

  - distFunc: passed through unchanged to modelBuilder (presumably the
    distance function used by the model -- confirm against the model classes)

  - holdOutFrac: the fraction of the data which should be reserved for the
    hold-out set (used to calculate error)

  - silent: a toggle used to control how much visual noise this makes as it goes

  - calcTotalError: a toggle used to indicate whether the error of the
    model should be calculated using the entire data set (when true) or just
    the hold-out set (when false)

  - replacementSelection (keyword argument, default 0): when true the split
    is drawn with replacement (non-legacy SplitIndices); otherwise it is drawn
    without replacement using the legacy SplitIndices behavior

**Returns**

  a 2-tuple: (the trained model, its validation error as computed by
  CrossValidate)

  The indices of the training examples are also stored on the returned
  model as its _trainIndices attribute.

"""

nTot = len(examples)
# 'replacementSelection' switches between the legacy split without
# replacement and a non-legacy split with replacement.
if not kwargs.get('replacementSelection', 0):
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=1,
                                                       replacement=0)
else:
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac, silent=1, legacy=0,
                                                       replacement=1)
trainExamples = [examples[x] for x in trainIndices]
testExamples = [examples[x] for x in testIndices]

nTrain = len(trainExamples)

if not silent:
    print("Training with %d examples" % (nTrain))

knnMod = modelBuilder(numNeigh, attrs, distFunc)

knnMod.SetTrainingExamples(trainExamples)
knnMod.SetTestExamples(testExamples)

# NOTE(review): if the split yields no test examples and calcTotalError is
# false, CrossValidate raises (it divides by the number of test examples).
# Hold-out evaluation passes appendExamples=1 while whole-data-set evaluation
# passes 0; the reason for the difference is not visible here.
if not calcTotalError:
    xValError, _ = CrossValidate(knnMod, testExamples, appendExamples=1)
else:
    xValError, _ = CrossValidate(knnMod, examples, appendExamples=0)

if not silent:
    # '%%' emits a literal percent sign just before the formatted value
    print('Validation error was %%%4.2f' % (100 * xValError))

# record which examples were used to train the model
knnMod._trainIndices = trainIndices
return knnMod, xValError
121