CSvmGridSearchTutorial.cpp
#include <shark/Models/Kernels/GaussianRbfKernel.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/Algorithms/Trainers/CSvmTrainer.h>
#include <shark/Data/DataDistribution.h>

#include <shark/ObjectiveFunctions/CrossValidationError.h>
#include <shark/Algorithms/DirectSearch/GridSearch.h>
#include <shark/Algorithms/JaakkolaHeuristic.h>

using namespace shark;
using namespace std;


int main()
{
	// problem definition
	Chessboard prob;
	ClassificationDataset dataTrain = prob.generateDataset(200);
	ClassificationDataset dataTest = prob.generateDataset(10000);

	// SVM setup
	GaussianRbfKernel<> kernel(0.5, true); // second flag: encode gamma unconstrained, i.e., as log(gamma)
	KernelClassifier<RealVector> svm;
	bool offset = true;
	bool unconstrained = true;
	CSvmTrainer<RealVector> trainer(&kernel, 1.0, offset, unconstrained);

	// cross-validation error
	const unsigned int K = 5; // number of folds
	ZeroOneLoss<unsigned int> loss;
	CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(dataTrain, K);
	CrossValidationError<KernelClassifier<RealVector>, unsigned int> cvError(
		folds, &trainer, &svm, &trainer, &loss
	);


	// find best parameters

	// use Jaakkola's heuristic as a starting point for the grid-search
	JaakkolaHeuristic ja(dataTrain);
	double ljg = log(ja.gamma());
	cout << "Tommi Jaakkola says gamma = " << ja.gamma() << " and ln(gamma) = " << ljg << endl;

	GridSearch grid;
	vector<double> min(2);
	vector<double> max(2);
	vector<size_t> sections(2);
	min[0] = ljg - 4.0; max[0] = ljg + 4.0; sections[0] = 9;  // kernel parameter gamma (log-encoded)
	min[1] = 0.0; max[1] = 10.0; sections[1] = 11;            // regularization parameter C (log-encoded, since unconstrained = true)
	grid.configure(min, max, sections);
	grid.step(cvError);

	// train model on the full dataset
	trainer.setParameterVector(grid.solution().point);
	trainer.train(svm, dataTrain);
	cout << "grid.solution().point " << grid.solution().point << endl;
	cout << "C =\t" << trainer.C() << endl;
	cout << "gamma =\t" << kernel.gamma() << endl;

	// evaluate
	Data<unsigned int> output = svm(dataTrain.inputs());
	double train_error = loss.eval(dataTrain.labels(), output);
	cout << "training error:\t" << train_error << endl;
	output = svm(dataTest.inputs());
	double test_error = loss.eval(dataTest.labels(), output);
	cout << "test error: \t" << test_error << endl;
}
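
Besides the training and test errors, it can be useful to report how well the selected parameter combination did on the model-selection criterion itself. The following is a minimal sketch, not part of the original example; it assumes that grid.solution() also exposes the best objective value as .value (alongside .point, which is used above) and that the CrossValidationError object can be evaluated at a parameter point via eval(). Such lines could be placed right after grid.step(cvError):

	// cross-validation error of the best grid point found by grid.step(cvError)
	// (assumption: solution() provides the objective value as .value)
	cout << "best CV error (grid):\t" << grid.solution().value << endl;
	// alternatively, re-evaluate the objective at the chosen parameters
	// (assumption: cvError.eval() accepts a parameter vector like any objective function)
	cout << "best CV error (re-eval):\t" << cvError.eval(grid.solution().point) << endl;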