// McSvmLLWTrainer.h
// (Doxygen page residue: "Go to the documentation of this file.")
//===========================================================================
/*!
 *
 *
 * \brief Trainer for the Multi-class Support Vector Machine by Lee, Lin, and Wahba
 *
 *
 *
 *
 * \author T. Glasmachers
 * \date -
 *
 *
 * \par Copyright 1995-2015 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://image.diku.dk/shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
35 
36 
37 #ifndef SHARK_ALGORITHMS_MCSVMLLWTRAINER_H
38 #define SHARK_ALGORITHMS_MCSVMLLWTRAINER_H
39 
40 
44 
48 
49 namespace shark {
50 
51 
///
/// \brief Training of the multi-category SVM by Lee, Lin and Wahba (LLW).
///
/// This is a special support vector machine variant for
/// classification of more than two classes. Given are data
/// tuples \f$ (x_i, y_i) \f$ with x-component denoting input
/// and y-component denoting the label 1, ..., d (see the tutorial on
/// label conventions; the implementation uses values 0 to d-1),
/// a kernel function k(x, x') and a regularization
/// constant C > 0. Let H denote the kernel induced
/// reproducing kernel Hilbert space of k, and let \f$ \phi \f$
/// denote the corresponding feature map.
/// Then the SVM classifier is the function
/// \f[
///     h(x) = \arg \max (f_c(x))
/// \f]
/// \f[
///     f_c(x) = \langle w_c, \phi(x) \rangle + b_c
/// \f]
/// \f[
///     f = (f_1, \dots, f_d)
/// \f]
/// with class-wise coefficients w_c and b_c given by the
/// (primal) optimization problem
/// \f[
///     \min \frac{1}{2} \sum_c \|w_c\|^2 + C \sum_i L(y_i, f(x_i))
/// \f]
/// \f[
///     \text{s.t. } \sum_c f_c = 0
/// \f]
/// The special property of the so-called LLW-machine is its
/// loss function, which arises from the application of the
/// discriminative sum operator to absolute margin violations.
/// Let \f$ h(m) = \max\{0, 1-m\} \f$ denote the hinge loss
/// as a function of the margin m, then the LLW loss is given
/// by
/// \f[
///     L(y, f(x)) = \sum_{c \not= y} h(-f_c(x))
/// \f]
///
/// For more details refer to the paper:<br/>
/// <p>Multicategory Support Vector Machines: Theory and Application to the Classification of Microarray %Data and Satellite Radiance %Data. Y. Lee, Y. Lin, and G. Wahba. Journal of the American Statistical Association 99(465), 2004.</p>
///
95 template <class InputType, class CacheType = float>
96 class McSvmLLWTrainer : public AbstractSvmTrainer<InputType, unsigned int>
97 {
98 public:
99 
100  typedef CacheType QpFloatType;
101 
105 
106  //! Constructor
107  //! \param kernel kernel function to use for training and prediction
108  //! \param C regularization parameter - always the 'true' value of C, even when unconstrained is set
109  //! \param unconstrained when a C-value is given via setParameter, should it be piped through the exp-function before using it in the solver?
110  McSvmLLWTrainer(KernelType* kernel, double C, bool unconstrained = false)
111  : base_type(kernel, C, unconstrained)
112  { }
113 
114  /// \brief From INameable: return the class name.
115  std::string name() const
116  { return "McSvmLLWTrainer"; }
117 
119  {
120  std::size_t ic = dataset.numberOfElements();
121  std::size_t classes = numberOfClasses(dataset);
122 
123  RealMatrix linear(ic, classes-1,1.0);
124  UIntVector rho(classes-1);
125  for (unsigned int p=0; p<classes-1; p++)
126  rho(p) = p;
127 
128  QpSparseArray<QpFloatType> nu(classes * (classes-1), classes, classes*(classes-1));
129  for (unsigned int r=0, y=0; y<classes; y++)
130  {
131  for (unsigned int p=0, pp=0; p<classes-1; p++, pp++, r++)
132  {
133  if (pp == y) pp++;
134  nu.add(r, pp, (QpFloatType)-1.0);
135  }
136  }
137 
138  QpSparseArray<QpFloatType> M(classes * (classes-1) * classes, classes-1, classes * (classes-1) * (classes-1));
139  QpFloatType mood = (QpFloatType)(-1.0 / (double)classes);
140  QpFloatType val = (QpFloatType)1.0 + mood;
141  for (unsigned int r=0, yv=0; yv<classes; yv++)
142  {
143  for (unsigned int pv=0, ppv=0; pv<classes-1; pv++, ppv++)
144  {
145  if (ppv == yv) ppv++;
146  for (unsigned int yw=0; yw<classes; yw++, r++)
147  {
148  M.setDefaultValue(r, mood);
149  if (ppv != yw)
150  {
151  unsigned int pw = ppv - (ppv > yw ? 1 : 0);
152  M.add(r, pw, val);
153  }
154  }
155  }
156  }
157 
158  typedef KernelMatrix<InputType, QpFloatType> KernelMatrixType;
159  typedef CachedMatrix< KernelMatrixType > CachedMatrixType;
160  typedef PrecomputedMatrix< KernelMatrixType > PrecomputedMatrixType;
161 
162  // solve the problem
163  RealMatrix alpha(ic,classes-1);
164  RealVector bias(classes,0);
165  KernelMatrixType km(*base_type::m_kernel, dataset.inputs());
167  {
168  PrecomputedMatrixType matrix(&km);
169  QpMcBoxDecomp< PrecomputedMatrixType > problem(matrix, M, dataset.labels(), linear, this->C());
171  problem.setShrinking(base_type::m_shrinking);
172  if(this->m_trainOffset){
173  BiasSolver< PrecomputedMatrixType > biasSolver(&problem);
174  biasSolver.solve(bias,base_type::m_stoppingcondition,nu);
175  }
176  else{
178  solver.solve( base_type::m_stoppingcondition, &prop);
179  }
180  alpha = problem.solution();
181  }
182  else
183  {
184  CachedMatrixType matrix(&km, base_type::m_cacheSize);
185  QpMcBoxDecomp< CachedMatrixType> problem(matrix, M, dataset.labels(), linear, this->C());
187  problem.setShrinking(base_type::m_shrinking);
188  if(this->m_trainOffset){
189  BiasSolver<CachedMatrixType> biasSolver(&problem);
190  biasSolver.solve(bias,base_type::m_stoppingcondition,nu);
191  }
192  else{
194  solver.solve( base_type::m_stoppingcondition, &prop);
195  }
196  alpha = problem.solution();
197  }
198 
199  svm.decisionFunction().setStructure(this->m_kernel,dataset.inputs(),this->m_trainOffset,classes);
200 
201  // write the solution into the model
202  for (std::size_t i=0; i<ic; i++)
203  {
204  unsigned int y = dataset.element(i).label;
205  for (std::size_t c=0; c<classes; c++)
206  {
207  double sum = 0.0;
208  unsigned int r = (classes-1) * y;
209  for (std::size_t p=0; p<classes-1; p++, r++)
210  sum += nu(r, c) * alpha(i,p);
211  svm.decisionFunction().alpha(i,c) = sum;
212  }
213  }
214  if (this->m_trainOffset)
215  svm.decisionFunction().offset() = bias;
216 
217  base_type::m_accessCount = km.getAccessCount();
218  if (this->sparsify())
219  svm.decisionFunction().sparsify();
220  }
221 };


224 template <class InputType>
226 {
227 public:
229 
230  LinearMcSvmLLWTrainer(double C, bool unconstrained = false)
231  : AbstractLinearSvmTrainer<InputType>(C, unconstrained){ }
232 
233  /// \brief From INameable: return the class name.
234  std::string name() const
235  { return "LinearMcSvmLLWTrainer"; }
236 
238  {
239  std::size_t dim = inputDimension(dataset);
240  std::size_t classes = numberOfClasses(dataset);
241 
242  QpMcLinearLLW<InputType> solver(dataset, dim, classes);
243  RealMatrix w = solver.solve(this->C(), this->stoppingCondition(), &this->solutionProperties(), this->verbosity() > 0);
244  model.decisionFunction().setStructure(w);
245  }
246 };


// shorthands for unified naming scheme; we resort to #define
// statements since old c++ does not support templated typedefs
#define McSvmADSTrainer McSvmLLWTrainer
#define LinearMcSvmADSTrainer LinearMcSvmLLWTrainer


}
#endif