McReinforcedSvmTrainer.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Trainer for the Reinforced Multi-class Support Vector Machine
6  *
7  *
8  *
9  *
10  * \author T. Glasmachers
11  * \date 2014
12  *
13  *
14  * \par Copyright 1995-2015 Shark Development Team
15  *
16  * <BR><HR>
17  * This file is part of Shark.
18  * <http://image.diku.dk/shark/>
19  *
20  * Shark is free software: you can redistribute it and/or modify
21  * it under the terms of the GNU Lesser General Public License as published
22  * by the Free Software Foundation, either version 3 of the License, or
23  * (at your option) any later version.
24  *
25  * Shark is distributed in the hope that it will be useful,
26  * but WITHOUT ANY WARRANTY; without even the implied warranty of
27  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
28  * GNU Lesser General Public License for more details.
29  *
30  * You should have received a copy of the GNU Lesser General Public License
31  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
32  *
33  */
34 //===========================================================================
35 
36 
37 #ifndef SHARK_ALGORITHMS_MCREINFORCEDSVMTRAINER_H
38 #define SHARK_ALGORITHMS_MCREINFORCEDSVMTRAINER_H
39 
40 
43 
47 
48 
49 namespace shark {
50 
51 
52 ///
53 /// \brief Training of reinforced-SVM for multi-category classification.
54 ///
55 /// The reinforced SVM was introduced in the article<br/>
56 /// <p>Reinforced multicategory support vector machines. Liu, Yufeng, and Ming Yuan. Journal of Computational and Graphical Statistics 20.4 (2011): 901-919.</p>
57 /// Its loss function has a parameter gamma which is fixed in this
58 /// implementation to its default value of 1/2.
59 ///
// NOTE(review): this text is a scraped Doxygen source listing, not a compilable
// header. Listing lines 66-68, 82, 94, 146, 150, 157, 166 and 173 are absent
// from this extract; the names KernelType, base_type, svm, dataset, batch,
// prop and solver used below are declared on those missing lines. Restore
// them from the original McReinforcedSvmTrainer.h before relying on this text.
//
// Kernel trainer for the reinforced multi-class SVM. train() builds the dual
// problem description (linear term, nu, M), solves it with QpMcBoxDecomp
// (through BiasSolver when an offset is trained) and copies the solution
// into the classifier's decision function.
//   InputType - type of the training inputs
//   CacheType - scalar type of the kernel matrix and of the sparse arrays nu/M
60 template <class InputType, class CacheType = float>
61 class McReinforcedSvmTrainer : public AbstractSvmTrainer<InputType, unsigned int>
62 {
63 public:
64 
    // Scalar type used for the kernel matrix (KernelMatrixType) and for the
    // QpSparseArray problem descriptions nu and M built in train().
65  typedef CacheType QpFloatType;
69 
70  //! Constructor
71  //! \param kernel kernel function to use for training and prediction
72  //! \param C regularization parameter - always the 'true' value of C, even when unconstrained is set
73  //! \param unconstrained when a C-value is given via setParameter, should it be piped through the exp-function before using it in the solver?
74  McReinforcedSvmTrainer(KernelType* kernel, double C, bool unconstrained = false)
75  : base_type(kernel, C, unconstrained)
76  { }
77 
78  /// \brief From INameable: return the class name.
79  std::string name() const
80  { return "McReinforcedSvmTrainer"; }
81 
    // Body of train(svm, dataset); its signature (listing line 82) is missing
    // from this extract.
83  {
84  std::size_t ic = dataset.numberOfElements();
85  std::size_t classes = numberOfClasses(dataset);
86 
87  // prepare the problem description
    // linear: ic x classes linear term of the dual. Every entry is 1.0,
    // except the entry at each example's own label, which is classes - 1.
88  RealMatrix linear(ic, classes, 1.0);
89  {
90  typename LabeledData<InputType, unsigned int>::LabelContainer const& labels = dataset.labels();
91  std::size_t i=0;
    // i walks all examples in dataset order, batch by batch
    // (`batch` is declared on missing listing line 94).
92  for (std::size_t b=0; b<labels.numberOfBatches(); b++)
93  {
95  for (std::size_t e=0; e<boost::size(batch); e++)
96  {
97  unsigned int const& l = shark::get(batch, e);
98  linear(i, l) = classes - 1.0; // self-margin target value of reinforced SVM loss
99  i++;
100  }
101  }
102  }
103 
    // nu: classes*classes rows with exactly one entry each. Row y*classes + p
    // holds +1 in column p when p == y and -1 otherwise. It is passed to
    // BiasSolver and used below to map the dual solution to model alphas.
104  QpSparseArray<QpFloatType> nu(classes*classes, classes, classes*classes);
105  for (unsigned int r=0, y=0; y<classes; y++)
106  {
107  for (unsigned int p=0; p<classes; p++, r++)
108  {
109  nu.add(r, p, (QpFloatType)((p == y) ? 1.0 : -1.0));
110  }
111  }
112 
    // M: classes^3 rows, enumerated as r = (yv*classes + pv)*classes + yw.
    // With c_ne = -1/classes and c_eq = 1 - 1/classes, row r has default
    // value sign*c_ne (sign = -1 if yv == pv, else +1) plus one explicit entry
    // when yw == pv and two otherwise (hence the 2*classes^3 capacity).
    // M is handed to QpMcBoxDecomp; presumably it describes how dual variable
    // (yv, pv) couples to the variables of an example labelled yw -- confirm
    // against QpMcBoxDecomp.
113  QpSparseArray<QpFloatType> M(classes * classes * classes, classes, 2 * classes * classes * classes);
114  QpFloatType c_ne = (QpFloatType)(-1.0 / (double)classes);
115  QpFloatType c_eq = (QpFloatType)1.0 + c_ne;
116  for (unsigned int r=0, yv=0; yv<classes; yv++)
117  {
118  for (unsigned int pv=0; pv<classes; pv++)
119  {
120  QpFloatType sign = QpFloatType((yv == pv) ? -1 : 1);//cast to keep MSVC happy...
121  for (unsigned int yw=0; yw<classes; yw++, r++)
122  {
123  M.setDefaultValue(r, sign * c_ne);
124  if (yw == pv)
125  {
126  M.add(r, pv, -sign * c_eq);
127  }
128  else
129  {
130  M.add(r, pv, sign * c_eq);
131  M.add(r, yw, -sign * c_ne);
132  }
133  }
134  }
135  }
136 
137  typedef KernelMatrix<InputType, QpFloatType> KernelMatrixType;
138  typedef CachedMatrix< KernelMatrixType > CachedMatrixType;
139  typedef PrecomputedMatrix< KernelMatrixType > PrecomputedMatrixType;
140 
    // Kernel matrix of the trainer's kernel over dataset.inputs().
141  KernelMatrixType km(*base_type::m_kernel, dataset.inputs());
142 
143  RealMatrix alpha(ic,classes,0.0);
144  RealVector bias(classes,0.0);
145  // solve the problem
    // NOTE(review): the branch condition (listing line 146) is missing from
    // this extract. The first branch solves on a PrecomputedMatrix of km, the
    // else branch on a CachedMatrix constructed with base_type::m_cacheSize.
    // In both, when m_trainOffset is set, BiasSolver is run with bias and nu
    // (bias is later copied into the model's offset); otherwise `solver`
    // (declared on a missing line) runs with the stopping condition and
    // receives a pointer to `prop`.
147  {
148  PrecomputedMatrixType matrix(&km);
149  QpMcBoxDecomp< PrecomputedMatrixType> problem(matrix, M, dataset.labels(), linear, this->C());
151  problem.setShrinking(base_type::m_shrinking);
152  if(this->m_trainOffset){
153  BiasSolver<PrecomputedMatrixType> biasSolver(&problem);
154  biasSolver.solve(bias,base_type::m_stoppingcondition,nu);
155  }
156  else{
158  solver.solve( base_type::m_stoppingcondition, &prop);
159  }
160  alpha = problem.solution();
161  }
162  else
163  {
164  CachedMatrixType matrix(&km, base_type::m_cacheSize);
165  QpMcBoxDecomp< CachedMatrixType> problem(matrix, M, dataset.labels(), linear, this->C());
167  problem.setShrinking(base_type::m_shrinking);
168  if(this->m_trainOffset){
169  BiasSolver<CachedMatrixType> biasSolver(&problem);
170  biasSolver.solve(bias,base_type::m_stoppingcondition,nu);
171  }
172  else{
174  solver.solve( base_type::m_stoppingcondition, &prop);
175  }
176  alpha = problem.solution();
177  }
178 
179  svm.decisionFunction().setStructure(this->m_kernel,dataset.inputs(),this->m_trainOffset,classes);
180 
181  // write the solution into the model
    // Model coefficient (i, c) = sum_p nu(classes*y + p, c) * alpha(i, p),
    // where y is the label of example i. Row classes*y + p of nu has its only
    // entry in column p, so (assuming absent entries read as zero) this
    // reduces to alpha(i, c) for c == y and -alpha(i, c) otherwise.
182  for (std::size_t i=0; i<ic; i++)
183  {
184  unsigned int y = dataset.element(i).label;
185  for (unsigned int c=0; c<classes; c++)
186  {
187  double sum = 0.0;
188  unsigned int r = classes * y;
189  for (unsigned int p=0; p<classes; p++, r++)
190  sum += nu(r, c) * alpha(i, p);
191  svm.decisionFunction().alpha(i,c) = sum;
192  }
193  }
    // The offset is only written when one was trained.
194  if (this->m_trainOffset)
195  svm.decisionFunction().offset() = bias;
196 
    // Store km's access counter on the trainer.
197  base_type::m_accessCount = km.getAccessCount();
198  if (this->sparsify())
199  svm.decisionFunction().sparsify();
200  }
201 };
202 
203 
// Linear (kernel-free) trainer for the reinforced multi-class SVM. train()
// delegates the optimization to QpMcLinearReinforced and stores the
// resulting weight matrix in the model.
// NOTE(review): listing lines 205 (the class head -- presumably deriving from
// AbstractLinearSvmTrainer<InputType>, given the initializer below), 208 and
// 217 (the train() signature declaring `model` and `dataset`) are missing
// from this extract. Restore them from the original header before compiling.
204 template <class InputType>
206 {
207 public:
209 
    //! Constructor
    //! \param C regularization parameter, forwarded to AbstractLinearSvmTrainer
    //! \param unconstrained forwarded unchanged to AbstractLinearSvmTrainer
210  LinearMcSvmReinforcedTrainer(double C, bool unconstrained = false)
211  : AbstractLinearSvmTrainer<InputType>(C, unconstrained){ }
212 
213  /// \brief From INameable: return the class name.
214  std::string name() const
215  { return "LinearMcSvmReinforcedTrainer"; }
216 
    // Body of train(model, dataset); its signature (listing line 217) is
    // missing from this extract.
218  {
219  std::size_t dim = inputDimension(dataset);
220  std::size_t classes = numberOfClasses(dataset);
221 
    // The solver is built from the data, the input dimension and the class
    // count. solve() receives C, the stopping condition, a pointer to the
    // trainer's solution-properties record and a verbosity flag
    // (verbosity() > 0), and returns the weight matrix w.
222  QpMcLinearReinforced<InputType> solver(dataset, dim, classes);
223  RealMatrix w = solver.solve(this->C(), this->stoppingCondition(), &this->solutionProperties(), this->verbosity() > 0);
    // Only the weight matrix is passed on; no offset is set in this body.
224  model.decisionFunction().setStructure(w);
225  }
226 };
227 
228 
229 }
230 #endif