McSvm.cpp
Go to the documentation of this file.
#include <cstdio>

#include <shark/LinAlg/Base.h>
#include <shark/Rng/GlobalRng.h>
#include <shark/Data/Dataset.h>
#include <shark/Models/Kernels/KernelExpansion.h>
#include <shark/Algorithms/Trainers/AbstractSvmTrainer.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
19 
20 
21 using namespace shark;
22 
23 
24 // data generating distribution for our toy
25 // multi-category classification problem
26 /// @cond EXAMPLE_SYMBOLS
27 class Problem : public LabeledDataDistribution<RealVector, unsigned int>
28 {
29 public:
30  void draw(RealVector& input, unsigned int& label)const
31  {
32  label = Rng::discrete(0, 4);
33  input.resize(1);
34  input(0) = Rng::gauss() + 3.0 * label;
35  }
36 };
37 /// @endcond
38 
39 int main()
40 {
41  unsigned int i;
42 
43  // experiment settings
44  unsigned int ell = 30;
45  unsigned int tests = 100;
46  double C = 10.0;
47  double gamma = 0.5;
48 
49  // generate a very simple dataset with a little noise
50  Problem problem;
51  ClassificationDataset training = problem.generateDataset(ell);
52  ClassificationDataset test = problem.generateDataset(tests);
53 
54  // kernel function
55  GaussianRbfKernel<> kernel(gamma);
56 
57  // SVM us kernel classifiers
59 
60  // loss measuring classification errors
62 
63  // There are 9 trainers for multi-class SVMs in Shark which can train with or without bias:
65  trainer[0] = new McSvmOVATrainer<RealVector>(&kernel, C, false);
66  trainer[1] = new McSvmCSTrainer<RealVector>(&kernel, C, false);
67  trainer[2] = new McSvmWWTrainer<RealVector>(&kernel, C, false);
68  trainer[3] = new McSvmLLWTrainer<RealVector>(&kernel, C, false);
69  trainer[4] = new McSvmADMTrainer<RealVector>(&kernel, C, false);
70  trainer[5] = new McSvmATSTrainer<RealVector>(&kernel, C, false);
71  trainer[6] = new McSvmATMTrainer<RealVector>(&kernel, C, false);
72  trainer[7] = new McSvmMMRTrainer<RealVector>(&kernel, C, false);
73  trainer[8] = new McReinforcedSvmTrainer<RealVector>(&kernel, C, false);
74  trainer[9] = new McSvmOVATrainer<RealVector>(&kernel, C, true);
75  trainer[10] = new McSvmCSTrainer<RealVector>(&kernel, C, true);
76  trainer[11] = new McSvmWWTrainer<RealVector>(&kernel, C, true);
77  trainer[12] = new McSvmLLWTrainer<RealVector>(&kernel, C, true);
78  trainer[13] = new McSvmADMTrainer<RealVector>(&kernel, C, true);
79  trainer[14] = new McSvmATSTrainer<RealVector>(&kernel, C, true);
80  trainer[15] = new McSvmATMTrainer<RealVector>(&kernel, C, true);
81  trainer[16] = new McSvmMMRTrainer<RealVector>(&kernel, C, true);
82  trainer[17] = new McReinforcedSvmTrainer<RealVector>(&kernel, C, true);
83 
84  std::printf("SHARK multi-class SVM example - training 18 machines:\n");
85  for (i=0; i<18; i++)
86  {
87  trainer[i]->train(svm, training);
88  Data<unsigned int> output = svm(training.inputs());
89  double train_error = loss.eval(training.labels(), output);
90  output = svm(test.inputs());
91  double test_error = loss.eval(test.labels(), output);
92 
93  std::printf(
94  "[%2d] %10s %s iterations=%10d time=%9.4g seconds training error=%9.4g test error=%9.4g\n",
95  i,
96  trainer[i]->name().c_str(),
97  trainer[i]->trainOffset()? "with bias ":"without bias",
98  (int)trainer[i]->solutionProperties().iterations,
99  trainer[i]->solutionProperties().seconds,
100  train_error,
101  test_error
102  );
103  }
104 
105  //clean up
106  for (std::size_t i = 0; i < 18; ++i){
107  delete trainer[i];
108  }
109 }