NearestNeighborClassifier.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Nearest neighbor classification
6  *
7  *
8  *
9  * \author T. Glasmachers, O. Krause
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_NEARESTNEIGHBORCLASSIFIER_H
36 #define SHARK_MODELS_NEARESTNEIGHBORCLASSIFIER_H
37 
40 #include <algorithm>
41 namespace shark {
42 
43 
44 ///
45 /// \brief Nearest Neighbor Classifier.
46 ///
47 /// \par
48 /// The NearestNeighborClassifier predicts a class label
49 /// according to a local majority vote among its k nearest
50 /// neighbors, optionally weighted by inverse distance. It is
51 /// not specified how ties are broken.
52 ///
53 /// This model requires one of Shark's nearest neighbor search algorithms.
54 /// \see AbstractNearestNeighbors
/// \tparam InputType type of the input points; the model predicts unsigned int class labels.
55 template <class InputType>
56 class NearestNeighborClassifier : public AbstractModel<InputType, unsigned int>
57 {
58 public:
63 
64  /// \brief Type of distance-based weights.
/// Selects how each of the k neighbors contributes to the vote in eval():
/// UNIFORM adds 1 per neighbor; ONE_OVER_DISTANCE adds 1/distance, where a
/// distance below 1e-100 contributes a fixed weight of 1e100 instead.
/// NOTE(review): no `enum` keyword/name precedes the opening brace below,
/// although UNIFORM is referenced by eval() — confirm the enum declaration
/// line is present in the real header.
66  {
67  UNIFORM, ///< uniform (= no) distance-based weights
68  ONE_OVER_DISTANCE, ///< weight each neighbor's label with 1/distance
69  };
70 
71  ///\brief Constructor
72  ///
73  /// \param algorithm the used algorithm for nearst neighbor search
74  /// \param neighbors: number of neighbors
75  NearestNeighborClassifier(NearestNeighbors const* algorithm, std::size_t neighbors = 3)
76  : m_algorithm(algorithm)
77  , m_classes(numberOfClasses(algorithm->dataset()))
80  { }
81 
82  /// \brief From INameable: return the class name.
83  std::string name() const
84  { return "NearestNeighborClassifier"; }
85 
86 
87  /// return the number of neighbors
88  std::size_t neighbors() const{
89  return m_neighbors;
90  }
91 
92  /// set the number of neighbors
93  void setNeighbors(std::size_t neighbors){
95  }
96 
97  /// query the way distances enter as weights
/// \return the distance weighting scheme applied by eval()
/// NOTE(review): no function signature precedes the body below; confirm the
/// getter's declaration line is present in the real header.
99  { return m_distanceWeights; }
100 
101  /// set the way distances enter as weights
/// \param dw distance weighting scheme to be applied by subsequent eval() calls
/// NOTE(review): no function signature precedes this body either; confirm the
/// setter's declaration line is present in the real header.
103  { m_distanceWeights = dw; }
104 
105  /// get internal parameters of the model
106  virtual RealVector parameterVector() const{
107  RealVector parameters(1);
108  parameters(0) = (double)m_neighbors;
109  return parameters;
110  }
111 
112  /// set internal parameters of the model
113  virtual void setParameterVector(RealVector const& newParameters){
114  SHARK_CHECK(newParameters.size() == 1,
115  "[SoftNearestNeighborClassifier::setParameterVector] invalid number of parameters");
116  //~ SHARK_CHECK((std::size_t)newParameters(0) == newParameters(0) && newParameters(0) >= 1.0,
117  //~ "[SoftNearestNeighborClassifier::setParameterVector] invalid number of neighbors");
118  m_neighbors = (std::size_t)newParameters(0);
119  }
120 
121  /// return the size of the parameter vector
122  virtual std::size_t numberOfParameters() const{
123  return 1;
124  }
125 
126  boost::shared_ptr<State> createState()const{
127  return boost::shared_ptr<State>(new EmptyState());
128  }
129 
130  using base_type::eval;
131 
132  /// soft k-nearest-neighbor prediction
133  void eval(BatchInputType const& patterns, BatchOutputType& output, State& state)const{
134  std::size_t numPatterns = shark::size(patterns);
135  std::vector<typename NearestNeighbors::DistancePair> neighbors = m_algorithm->getNeighbors(patterns,m_neighbors);
136 
137  output.resize(numPatterns);
138  output.clear();
139 
140  for(std::size_t p = 0; p != numPatterns;++p){
141  std::vector<double> histogram(m_classes, 0.0);
142  for ( std::size_t k = 0; k != m_neighbors; ++k){
143  if (m_distanceWeights == UNIFORM) histogram[neighbors[p*m_neighbors+k].value]++;
144  else
145  {
146  double d = neighbors[p*m_neighbors+k].key;
147  if (d < 1e-100) histogram[neighbors[p*m_neighbors+k].value] += 1e100;
148  else histogram[neighbors[p*m_neighbors+k].value] += 1.0 / d;
149  }
150  }
151  output(p) = static_cast<unsigned int>(std::max_element(histogram.begin(),histogram.end()) - histogram.begin());
152  }
153  }
154 
155  /// from ISerializable, reads a model from an archive
156  void read(InArchive& archive){
157  archive & m_neighbors;
158  archive & m_classes;
159  }
160 
161  /// from ISerializable, writes a model to an archive
162  void write(OutArchive& archive) const{
163  archive & m_neighbors;
164  archive & m_classes;
165  }
166 
167 protected:
/// nearest neighbor search algorithm queried by eval(). Only a raw pointer is
/// stored, so the algorithm must stay alive while this classifier is used
/// (presumably not owned — nothing here deletes it).
168  NearestNeighbors const* m_algorithm;
169 
170  /// number of classes
/// (taken from the algorithm's dataset in the constructor; sizes the vote histogram in eval())
171  std::size_t m_classes;
172 
173  /// number of neighbors to be taken into account
174  std::size_t m_neighbors;
175 
176  /// type of distance-based weights computation
/// NOTE(review): no member declaration follows this comment here, yet
/// m_distanceWeights is used by eval() and by the weight accessors — confirm
/// the declaration is present in the real header.
178 };
179 
180 
181 
182 }
183 #endif