CVFolds.cpp
//header needed for cross validation
#include <shark/Data/CVDatasetTools.h>

//headers needed for our test problem
//(note: the include paths below follow Shark 3.0 and may differ in other versions)
#include <shark/Models/FFNet.h>
#include <shark/ObjectiveFunctions/ErrorFunction.h>
#include <shark/ObjectiveFunctions/Loss/SquaredLoss.h>
#include <shark/ObjectiveFunctions/Regularizer.h>
#include <shark/ObjectiveFunctions/CombinedObjectiveFunction.h>
#include <shark/Algorithms/GradientDescent/Rprop.h>
#include <shark/Rng/Uniform.h>

//we use an artificial learning problem
#include <shark/Data/DataDistribution.h>

using namespace shark;
using namespace std;

///In this example, you will learn to create and use partitions
///for cross validation.
///This tutorial describes a handmade solution which does not use the cross-validation
///error function that is also provided by Shark. We do this because it gives a better
///insight into what cross validation actually does.

///The test problem receives a training set, a validation set, and the regularization
///parameter and returns the validation error. Skip ahead to main() if you are not
///interested in the problem itself. But here you can also see how to create
///regularized error functions, so maybe it's still worth taking a look ;)
double trainProblem(const RegressionDataset& training, RegressionDataset const& validation, double regularization){
	//we create a feed forward network with 20 hidden neurons.
	//this should create a lot of overfitting when not regularized
	//(we assume logistic hidden neurons and a linear output neuron here;
	//the exact FFNet type may differ between Shark versions)
	FFNet<LogisticNeuron, LinearNeuron> network;
	network.setStructure(1, 20, 1);

	//initialize with random weights between -5 and 5
	//todo: find better initialization for cv-error...
	RealVector startingPoint(network.numberOfParameters());
	Uniform<> uni(Rng::globalRng, -5, 5);
	std::generate(startingPoint.begin(), startingPoint.end(), uni);

	//the error function is a combination of the MSE and the 2-norm of the weights
	SquaredLoss<> loss;
	ErrorFunction error(training, &network, &loss);
	TwoNormRegularizer regularizer;

	//combine both functions into a single regularized objective
	//(the template arguments follow Shark 3.0 and may differ in other versions)
	CombinedObjectiveFunction<RealVector, double> regularizedError;
	regularizedError.add(error);
	regularizedError.add(regularization, regularizer);

	//now train for a number of iterations using Rprop
	IRpropPlus optimizer;
	//initialize with our predefined point, since
	//the combined function can't propose one.
	optimizer.init(regularizedError, startingPoint);
	for(unsigned iter = 0; iter != 5000; ++iter){
		optimizer.step(regularizedError);
	}

	//validate and return the error without regularization
	return loss(network(validation.inputs()), validation.labels());
}

/// What is cross validation (CV)? In cross validation the dataset is partitioned into
/// several validation sets. For a given validation set, the remainder of the dataset
/// - every other validation set - forms the training part. During every evaluation of the error function,
/// the problem is solved using the training part and the error is computed on the validation part.
/// The mean of the validation errors over all folds is the final error of the solution found.
/// This procedure reduces the bias introduced by a particular split of the dataset and makes
/// overfitting of the solution harder.
int main(){
	//we first create the problem. In this simple tutorial,
	//it's only the 1D wave function sin(x)/x + noise
	Wave wave;
	RegressionDataset dataset = wave.generateDataset(100);

	//now we want to create the cv folds. For this, we
	//use CVDatasetTools.h. There are a few functions
	//to create folds; in this case, we create 4
	//partitions of the same size, so we have 75 training
	//and 25 validation points per fold
	CVFolds<RegressionDataset> folds = createCVSameSize(dataset, 4);

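	//(illustrative sketch, not part of the original tutorial) print the size of
	//every fold to make the 75/25 split visible; numberOfElements() counts the
	//points in a dataset
	for (std::size_t fold = 0; fold != folds.size(); ++fold){
		std::cout << "fold " << fold
			<< ": training = " << folds.training(fold).numberOfElements()
			<< ", validation = " << folds.validation(fold).numberOfElements()
			<< std::endl;
	}
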
	//now we want to use the folds to find the best regularization
	//parameter for our problem. We use a grid search to accomplish this
	double bestValidationError = 1e5;
	double bestRegularization = 0;
	for (double regularization = 0; regularization < 1.e-4; regularization += 1.e-5) {
		double result = 0;
		for (std::size_t fold = 0; fold != folds.size(); ++fold){ //CV
			// access the fold
			RegressionDataset training = folds.training(fold);
			RegressionDataset validation = folds.validation(fold);
			// train on the training part and accumulate the validation error
			result += trainProblem(training, validation, regularization);
		}
		result /= folds.size();

		// check whether this regularization parameter leads to better results
		if (result < bestValidationError)
		{
			bestValidationError = result;
			bestRegularization = regularization;
		}

		// print status
		std::cout << regularization << " " << result << std::endl;
	}

	// print the best values found
	cout << "RESULTS: " << std::endl;
	cout << "======== " << std::endl;
	cout << "best validation error: " << bestValidationError << std::endl;
	cout << "best regularization: " << bestRegularization << std::endl;
}