McSvmATMTrainer.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Trainer for the ATM Multi-class Support Vector Machine
6  *
7  *
8  *
9  *
10  * \author T. Glasmachers
11  * \date -
12  *
13  *
14  * \par Copyright 1995-2015 Shark Development Team
15  *
16  * <BR><HR>
17  * This file is part of Shark.
18  * <http://image.diku.dk/shark/>
19  *
20  * Shark is free software: you can redistribute it and/or modify
21  * it under the terms of the GNU Lesser General Public License as published
22  * by the Free Software Foundation, either version 3 of the License, or
23  * (at your option) any later version.
24  *
25  * Shark is distributed in the hope that it will be useful,
26  * but WITHOUT ANY WARRANTY; without even the implied warranty of
27  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
28  * GNU Lesser General Public License for more details.
29  *
30  * You should have received a copy of the GNU Lesser General Public License
31  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
32  *
33  */
34 //===========================================================================
35 
36 
37 #ifndef SHARK_ALGORITHMS_MCSVMATMTRAINER_H
38 #define SHARK_ALGORITHMS_MCSVMATMTRAINER_H
39 
#include <shark/Algorithms/Trainers/AbstractSvmTrainer.h>
#include <shark/Algorithms/QP/QpMcSimplexDecomp.h>
#include <shark/Algorithms/QP/QpMcLinear.h>
#include <shark/Algorithms/QP/QpSolver.h>
#include <shark/Models/Kernels/KernelExpansion.h>
50 namespace shark {
51 
52 
53 ///
54 /// \brief Training of ATM-SVMs for multi-category classification.
55 ///
56 /// The ATM-SVM is a special support vector machine variant for
57 /// classification of more than two classes. Given are data
58 /// tuples \f$ (x_i, y_i) \f$ with x-component denoting input
59 /// and y-component denoting the label 1, ..., d (see the tutorial on
60 /// label conventions; the implementation uses values 0 to d-1),
61 /// a kernel function k(x, x') and a regularization
62 /// constant C > 0. Let H denote the kernel induced
63 /// reproducing kernel Hilbert space of k, and let \f$ \phi \f$
64 /// denote the corresponding feature map.
65 /// Then the SVM classifier is the function
66 /// \f[
67 /// h(x) = \arg \max (f_c(x))
68 /// \f]
69 /// \f[
70 /// f_c(x) = \langle w_c, \phi(x) \rangle + b_c
71 /// \f]
72 /// \f[
73 /// f = (f_1, \dots, f_d)
74 /// \f]
75 /// with class-wise coefficients w_c and b_c given by the
76 /// (primal) optimization problem
77 /// \f[
78 /// \min \frac{1}{2} \sum_c \|w_c\|^2 + C \sum_i L(y_i, f(x_i))
79 /// \f]
80 /// \f[
81 /// \text{s.t. } \sum_c f_c = 0
82 /// \f]
83 /// The special property of the so-called ATM machine is its
84 /// loss function, which arises from the application of the
85 /// total maximum operator to absolute margin violations.
86 /// Let \f$ h(m) = \max\{0, 1-m\} \f$ denote the hinge loss
87 /// as a function of the margin m, then the ATM loss is given
88 /// by
89 /// \f[
90 /// L(y, f(x)) = \max_c h((2 \cdot \delta_{c,y} - 1) \cdot f_c(x))
91 /// \f]
92 /// where the Kronecker delta is one if its arguments agree and
93 /// zero otherwise.
94 ///
95 /// For more details refer to the technical report:<br/>
96 /// <p>Fast Training of Multi-Class Support Vector Machines. &Uuml; Dogan, T. Glasmachers, and C. Igel, Technical Report 2011/3, Department of Computer Science, University of Copenhagen, 2011.</p>
97 ///
98 template <class InputType, class CacheType = float>
99 class McSvmATMTrainer : public AbstractSvmTrainer<InputType, unsigned int>
100 {
101 public:
102  typedef CacheType QpFloatType;
106 
107  //! Constructor
108  //! \param kernel kernel function to use for training and prediction
109  //! \param C regularization parameter - always the 'true' value of C, even when unconstrained is set
110  //! \param unconstrained when a C-value is given via setParameter, should it be piped through the exp-function before using it in the solver?
111  McSvmATMTrainer(KernelType* kernel, double C, bool unconstrained = false)
112  : base_type(kernel, C, unconstrained)
113  { }
114 
115  /// \brief From INameable: return the class name.
116  std::string name() const
117  { return "McSvmATMTrainer"; }
118 
120  {
121  std::size_t ic = dataset.numberOfElements();
122  std::size_t classes = numberOfClasses(dataset);
123 
124  // prepare the problem description
125  RealMatrix linear(ic, classes,1.0);
126  QpSparseArray<QpFloatType> nu(classes*classes, classes, classes*classes);
127  {
128  for (unsigned int r=0, y=0; y<classes; y++)
129  {
130  for (unsigned int p=0; p<classes; p++, r++)
131  {
132  nu.add(r, p, (QpFloatType)((p == y) ? 1.0 : -1.0));
133  }
134  }
135  }
136  QpSparseArray<QpFloatType> M(classes * classes * classes, classes, 2 * classes * classes * classes);
137  {
138  QpFloatType c_ne = (QpFloatType)(-1.0 / (double)classes);
139  QpFloatType c_eq = (QpFloatType)1.0 + c_ne;
140  for (unsigned int r=0, yv=0; yv<classes; yv++)
141  {
142  for (unsigned int pv=0; pv<classes; pv++)
143  {
144  QpFloatType sign = QpFloatType((yv == pv) ? -1 : 1);//cast to keep MSVC happy...
145  for (unsigned int yw=0; yw<classes; yw++, r++)
146  {
147  M.setDefaultValue(r, sign * c_ne);
148  if (yw == pv)
149  {
150  M.add(r, pv, -sign * c_eq);
151  }
152  else
153  {
154  M.add(r, pv, sign * c_eq);
155  M.add(r, yw, -sign * c_ne);
156  }
157  }
158  }
159  }
160  }
161 
162  typedef KernelMatrix<InputType, QpFloatType> KernelMatrixType;
163  typedef CachedMatrix< KernelMatrixType > CachedMatrixType;
164  typedef PrecomputedMatrix< KernelMatrixType > PrecomputedMatrixType;
165 
166  KernelMatrixType km(*base_type::m_kernel, dataset.inputs());
167 
168  RealMatrix alpha(ic,classes,0.0);
169  RealVector bias(classes,0.0);
170  // solve the problem
172  {
173  PrecomputedMatrixType matrix(&km);
174  QpMcSimplexDecomp< PrecomputedMatrixType> problem(matrix, M, dataset.labels(), linear, this->C());
176  problem.setShrinking(base_type::m_shrinking);
177  if(this->m_trainOffset){
178  BiasSolverSimplex<PrecomputedMatrixType> biasSolver(&problem);
179  biasSolver.solve(bias,base_type::m_stoppingcondition,nu);
180  }
181  else{
183  solver.solve( base_type::m_stoppingcondition, &prop);
184  }
185  alpha = problem.solution();
186  }
187  else
188  {
189  CachedMatrixType matrix(&km, base_type::m_cacheSize);
190  QpMcSimplexDecomp< CachedMatrixType> problem(matrix, M, dataset.labels(), linear, this->C());
192  problem.setShrinking(base_type::m_shrinking);
193  if(this->m_trainOffset){
194  BiasSolverSimplex<CachedMatrixType> biasSolver(&problem);
195  biasSolver.solve(bias,base_type::m_stoppingcondition,nu);
196  }
197  else{
199  solver.solve( base_type::m_stoppingcondition, &prop);
200  }
201  alpha = problem.solution();
202  }
203 
204  svm.decisionFunction().setStructure(this->m_kernel,dataset.inputs(),this->m_trainOffset,classes);
205 
206  // write the solution into the model
207  for (std::size_t i=0; i<ic; i++)
208  {
209  unsigned int y = dataset.element(i).label;
210  for (unsigned int c=0; c<classes; c++)
211  {
212  double sum = 0.0;
213  unsigned int r = classes * y;
214  for (unsigned int p=0; p<classes; p++, r++)
215  sum += nu(r, c) * alpha(i,p);
216  svm.decisionFunction().alpha(i,c) = sum;
217  }
218  }
219  if (this->m_trainOffset)
220  svm.decisionFunction().offset() = bias;
221 
222  base_type::m_accessCount = km.getAccessCount();
223  if (this->sparsify())
224  svm.decisionFunction().sparsify();
225  }
226 };
227 
228 
229 template <class InputType>
231 {
232 public:
234 
235  LinearMcSvmATMTrainer(double C, bool unconstrained = false)
236  : AbstractLinearSvmTrainer<InputType>(C, unconstrained){ }
237 
238  /// \brief From INameable: return the class name.
239  std::string name() const
240  { return "LinearMcSvmATMTrainer"; }
241 
243  {
244  std::size_t dim = inputDimension(dataset);
245  std::size_t classes = numberOfClasses(dataset);
246  QpMcLinearATM<InputType> solver(dataset, dim, classes);
247  RealMatrix w = solver.solve(this->C(), this->stoppingCondition(), &this->solutionProperties(), this->verbosity() > 0);
248  model.decisionFunction().setStructure(w);
249  }
250 };
251 
252 
253 }
254 #endif