KernelMeanClassifier.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief KernelMeanClassifier
6  *
7  *
8  *
9  * \author T. Glasmachers, C. Igel
10  * \date 2010, 2011
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 #ifndef SHARK_ALGORITHMS_TRAINERS_KERNELMEAN_H
35 #define SHARK_ALGORITHMS_TRAINERS_KERNELMEAN_H
36 
37 
40 #include <shark/Data/Dataset.h>
41 
42 namespace shark {
43 
44 /*! \brief Kernelized mean-classifier
45  *
46  * Computes the mean of the training data in feature space for each
47  * class and assigns a new data point to the class with the nearest
48  * mean.
 *
 * The training code visible below only accepts two-class problems: it
 * checks numberOfClasses(dataset) == 2 and that both classes have at
 * least one member. Each training point i receives the coefficient
 * 1/n0 if its label is 0 and -1/n1 if its label is 1, where n0 and n1
 * are the class sizes returned by classSizes(dataset). One additional
 * parameter is computed from the summed within-class kernel values, and
 * the negated parameter vector is handed to the model.
 *
 * The within-class kernel sums are computed with two nested loops over
 * all elements of the data set, i.e. one pass over the data set per
 * training element (see the todo notes in the code).
 *
 * NOTE(review): this text is a numbered source listing. The embedded
 * line numbers skip 54, 59, 79 and 102, so the corresponding source
 * lines are not visible here. Presumably these are the constructor,
 * the signature of train(), the declaration of ElementRef and the
 * declaration of the member mpe_kernel -- confirm against the original
 * header before relying on this listing.
49  */
50 template<class InputType>
51 class KernelMeanClassifier : public AbstractTrainer<KernelClassifier<InputType>, unsigned int>
52 {
53 public:
55 
 /// \brief Returns the fixed string "KernelMeanClassifier".
56  std::string name() const
57  { return "KernelMeanClassifier"; }
58 
 // NOTE(review): listing line 59 is missing at this point. The statements
 // below use `model` and `dataset` and their messages are tagged
 // "[KernelMeanClassifier::train]", so this is presumably the body of
 // train(model, dataset) -- confirm against the original header.

 // Guard: only data sets for which numberOfClasses() reports exactly 2 are accepted.
60  SHARK_CHECK(numberOfClasses(dataset) ==2, "[KernelMeanClassifier::train] not a binary class problem");
61 
 // Set up the model's decision function from the kernel and the training inputs.
 // NOTE(review): the meaning of the trailing `true` is not visible here;
 // presumably it enables an offset term, since one parameter beyond the
 // per-pattern coefficients is written below -- confirm.
62  model.decisionFunction().setStructure(mpe_kernel,dataset.inputs(),true);
63 
64  std::size_t patterns = dataset.numberOfElements();
65  std::vector<std::size_t> numClasses = classSizes(dataset); // indexed by label: [0] and [1] are read below
66  double coeffs[] = {0,0}; // per-class coefficient, indexed by label
67 
 // Guard: both classes must be non-empty, otherwise the divisions below would divide by zero.
68  SHARK_CHECK(numClasses[0] > 0, "[KernelMeanClassifier::train] class 0 has no class members" );
69  SHARK_CHECK(numClasses[1] > 0, "[KernelMeanClassifier::train] class 1 has no class members" );
70 
 // Class 0 points get weight +1/n0, class 1 points get weight -1/n1.
71  coeffs[0] = 1.0 / numClasses[0];
72  coeffs[1] = -1.0 / numClasses[1];
73 
74  // compute coefficients and bias term
75  double classBias[]={0.0,0.0}; // per class: sum of kernel values over all same-class pairs
76  RealVector params(patterns + 1); // one coefficient per pattern, plus one trailing entry set after the loop
77 
78  //todo: slow implementation without batch processing!
 // NOTE(review): listing line 79 is missing; ElementRef is not declared in
 // the visible text (presumably a typedef for the data set's element
 // reference type) -- confirm.
80  std::size_t i = 0; // position of the current element in params
81  BOOST_FOREACH(ElementRef element,dataset.elements()){
82 
83  unsigned int y = element.label;
84 
85  // compute and set coefficients
86  params(i) = coeffs[y];
87  ++i;
88  // compute values to calculate bias
 // Inner pass over the whole data set: only elements with the same
 // label as `element` contribute, each adding k(element, element2).
89  BOOST_FOREACH(ElementRef element2,dataset.elements()){
90  if (element2.label != y)
91  continue;
92  //todo: fast implementation should create batches of same class elements and process them!
93  classBias[y] += mpe_kernel->eval(element.input, element2.input);
94  }
95  }
96  // set bias
 // classBias[c] * sqr(coeffs[c]) is the within-class kernel sum divided by n_c^2.
 // NOTE(review): if the model adds this last parameter as an offset to
 // sum_i params(i) * k(x_i, x), a nearest-mean rule would need the offset
 // -0.5 * (classBias[0]*sqr(coeffs[0]) - classBias[1]*sqr(coeffs[1])),
 // i.e. the opposite sign relative to the coefficients set above. The
 // common negation below does not change that relative sign. Confirm
 // against the model's parameter layout and the unit tests.
97  params(patterns) = 0.5 * (classBias[0] * sqr(coeffs[0]) - classBias[1] * sqr(coeffs[1]));
98  // pass parameters to model, note the negation
99  model.setParameterVector(-params);
100  }
101 
 // NOTE(review): listing line 102 is missing; mpe_kernel is used above but
 // its declaration is not visible (presumably a pointer member holding the
 // kernel, given the `->eval` call) -- confirm.
103 };
104 
105 
106 }
107 #endif