Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
include
shark
Algorithms
Trainers
McSvmOVATrainer.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief Trainer for One-versus-all (one-versus-rest) Multi-class Support Vector Machines
6
*
7
*
8
*
9
*
10
* \author T. Glasmachers
11
* \date -
12
*
13
*
14
* \par Copyright 1995-2016 Shark Development Team
15
*
16
* <BR><HR>
17
* This file is part of Shark.
18
* <http://image.diku.dk/shark/>
19
*
20
* Shark is free software: you can redistribute it and/or modify
21
* it under the terms of the GNU Lesser General Public License as published
22
* by the Free Software Foundation, either version 3 of the License, or
23
* (at your option) any later version.
24
*
25
* Shark is distributed in the hope that it will be useful,
26
* but WITHOUT ANY WARRANTY; without even the implied warranty of
27
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
28
* GNU Lesser General Public License for more details.
29
*
30
* You should have received a copy of the GNU Lesser General Public License
31
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
32
*
33
*/
34
//===========================================================================
35
36
37
#ifndef SHARK_ALGORITHMS_MCSVMOVATRAINER_H
38
#define SHARK_ALGORITHMS_MCSVMOVATRAINER_H
39
40
41
#include <
shark/Algorithms/Trainers/AbstractSvmTrainer.h
>
42
#include <
shark/Algorithms/Trainers/CSvmTrainer.h
>
43
44
45
namespace
shark
{
46
47
48
///
49
/// \brief Training of a multi-category SVM by the one-versus-all (OVA) method.
50
///
51
/// This is a special support vector machine variant for
52
/// classification of more than two classes. Given are data
53
/// tuples \f$ (x_i, y_i) \f$ with x-component denoting input
54
/// and y-component denoting the label 1, ..., d (see the tutorial on
55
/// label conventions; the implementation uses values 0 to d-1),
56
/// a kernel function k(x, x') and a regularization
57
/// constant C > 0. Let H denote the kernel induced
58
/// reproducing kernel Hilbert space of k, and let \f$ \phi \f$
59
/// denote the corresponding feature map.
60
/// Then the SVM classifier is the function
61
/// \f[
62
/// h(x) = \arg \max (f_c(x))
63
/// \f]
64
/// \f[
65
/// f_c(x) = \langle w_c, \phi(x) \rangle + b_c
66
/// \f]
67
/// \f[
68
/// f = (f_1, \dots, f_d)
69
/// \f]
70
/// with class-wise coefficients w_c and b_c obtained by training
71
/// a standard C-SVM (see CSvmTrainer) with class c as the positive
72
/// and the union of all other classes as the negative class.
73
/// This is often a strong baseline method, and it is usually much
74
/// faster to train than other multi-category SVMs.
75
///
76
template
<
class
InputType,
class
CacheType =
float
>
77
class
McSvmOVATrainer
:
public
AbstractSvmTrainer
<InputType, unsigned int>
78
{
79
public
:
80
typedef
CacheType
QpFloatType
;
81
82
typedef
AbstractModel<InputType, RealVector>
ModelType
;
83
typedef
AbstractKernelFunction<InputType>
KernelType
;
84
typedef
AbstractSvmTrainer<InputType, unsigned int>
base_type
;
85
86
//! Constructor
87
//! \param kernel kernel function to use for training and prediction
88
//! \param C regularization parameter - always the 'true' value of C, even when unconstrained is set
89
//! \param offset whether to train offset/bias parameter
90
//! \param unconstrained when a C-value is given via setParameter, should it be piped through the exp-function before using it in the solver?
91
McSvmOVATrainer
(KernelType*
kernel
,
double
C
,
bool
offset,
bool
unconstrained =
false
)
92
: base_type(kernel, C, offset, unconstrained)
93
{ }
94
95
/// \brief From INameable: return the class name.
96
std::string
name
()
const
97
{
return
"McSvmOVATrainer"
; }
98
99
/// \brief Train a kernelized SVM.
100
void
train
(
KernelClassifier<InputType>
& svm,
const
LabeledData<InputType, unsigned int>
& dataset)
101
{
102
std::size_t classes =
numberOfClasses
(dataset);
103
svm.
decisionFunction
().setStructure(this->
m_kernel
,dataset.
inputs
(),this->
m_trainOffset
,classes);
104
105
base_type::m_solutionproperties
.type =
QpNone
;
106
base_type::m_solutionproperties
.
accuracy
= 0.0;
107
base_type::m_solutionproperties
.
iterations
= 0;
108
base_type::m_solutionproperties
.
value
= 0.0;
109
base_type::m_solutionproperties
.
seconds
= 0.0;
110
for
(
unsigned
int
c=0; c<classes; c++)
111
{
112
LabeledData<InputType, unsigned int>
bindata =
oneVersusRestProblem
(dataset, c);
113
KernelClassifier<InputType>
binsvm;
114
// TODO: maybe build the quadratic programs directly,
115
// in order to profit from cached and
116
// in particular from precomputed kernel
117
// entries!
118
CSvmTrainer<InputType, QpFloatType>
bintrainer(
base_type::m_kernel
, this->
C
(),this->
m_trainOffset
);
119
bintrainer.
setCacheSize
(
base_type::m_cacheSize
);
120
bintrainer.
sparsify
() =
false
;
121
bintrainer.
stoppingCondition
() =
base_type::stoppingCondition
();
122
bintrainer.
precomputeKernel
() =
base_type::precomputeKernel
();
// sub-optimal!
123
bintrainer.
shrinking
() =
base_type::shrinking
();
124
bintrainer.
s2do
() =
base_type::s2do
();
125
bintrainer.
verbosity
() =
base_type::verbosity
();
126
bintrainer.
train
(binsvm, bindata);
127
base_type::m_solutionproperties
.
iterations
+= bintrainer.
solutionProperties
().iterations;
128
base_type::m_solutionproperties
.
seconds
+= bintrainer.
solutionProperties
().seconds;
129
base_type::m_solutionproperties
.
accuracy
=
std::max
(
base_type::solutionProperties
().accuracy, bintrainer.
solutionProperties
().accuracy);
130
column
(svm.
decisionFunction
().alpha(), c) =
column
(binsvm.
decisionFunction
().alpha(), 0);
131
if
(this->
m_trainOffset
)
132
svm.
decisionFunction
().offset(c) = binsvm.
decisionFunction
().offset(0);
133
base_type::m_accessCount
+= bintrainer.
accessCount
();
134
}
135
136
if
(
base_type::sparsify
())
137
svm.
decisionFunction
().sparsify();
138
}
139
};
140
141
142
template
<
class
InputType>
143
class
LinearMcSvmOVATrainer
:
public
AbstractLinearSvmTrainer
<InputType>
144
{
145
public
:
146
typedef
AbstractLinearSvmTrainer<InputType>
base_type
;
147
148
LinearMcSvmOVATrainer
(
double
C
,
bool
unconstrained =
false
)
149
:
AbstractLinearSvmTrainer
<
InputType
>(C, unconstrained){ }
150
151
/// \brief From INameable: return the class name.
152
std::string
name
()
const
153
{
return
"LinearMcSvmOVATrainer"
; }
154
155
void
train
(
LinearClassifier<InputType>
& model,
const
LabeledData<InputType, unsigned int>
& dataset)
156
{
157
base_type::m_solutionproperties
.type =
QpNone
;
158
base_type::m_solutionproperties
.
accuracy
= 0.0;
159
base_type::m_solutionproperties
.
iterations
= 0;
160
base_type::m_solutionproperties
.
value
= 0.0;
161
base_type::m_solutionproperties
.
seconds
= 0.0;
162
163
std::size_t dim =
inputDimension
(dataset);
164
std::size_t classes =
numberOfClasses
(dataset);
165
RealMatrix
w
(classes, dim);
166
for
(
unsigned
int
c=0; c<classes; c++)
167
{
168
LabeledData<InputType, unsigned int>
bindata =
oneVersusRestProblem
(dataset, c);
169
QpBoxLinear<InputType>
solver(bindata, dim);
170
QpSolutionProperties
prop;
171
row
(w, c) = solver.
solve
(this->
C
(), 0.0,
base_type::m_stoppingcondition
, &prop,
base_type::m_verbosity
> 0);
172
base_type::m_solutionproperties
.
iterations
+= prop.
iterations
;
173
base_type::m_solutionproperties
.
seconds
+= prop.
seconds
;
174
base_type::m_solutionproperties
.
accuracy
=
std::max
(
base_type::solutionProperties
().accuracy, prop.
accuracy
);
175
}
176
model.
decisionFunction
().setStructure(w);
177
}
178
};
179
180
181
}
182
#endif