SoftNearestNeighborClassifier.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Soft/probabilistic nearest neighbor classifier for vector-valued data.
6  *
7  *
8  *
9  * \author T. Glasmachers, C. Igel, O.Krause
10  * \date 2012-2014
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_SOFTNEARESTNEIGHBOR_H
36 #define SHARK_MODELS_SOFTNEARESTNEIGHBOR_H
37 
38 
41 
42 namespace shark {
43 
44 /// \brief SoftNearestNeighborClassifier returns a probabilistic
45 /// classification by looking at the k nearest neighbors.
46 ///
47 /// For a given number C of classes, which is either derived from the
48 /// data set of the search algorithm or specified explicitly, a
49 /// C-dimensional real-valued vector is returned for each query point.
50 /// Each component corresponds to a class and contains the (optionally
51 /// distance-weighted) fraction of the K nearest neighbors in that class.
52 ///
53 template <class InputType>
54 class SoftNearestNeighborClassifier : public AbstractModel<InputType, RealVector>
55 {
56 public:
61 
62  /// \brief Type of distance-based weights.
64  {
65  UNIFORM, ///< uniform (= no) distance-based weights
66  ONE_OVER_DISTANCE, ///< weight each neighbor's label with 1/distance
67  };
68 
69  /// \brief Constructor
70  ///
71  /// \param algorithm the used algorithm for nearest neighbor search
72  /// \param neighbors number of neighbors
73  SoftNearestNeighborClassifier(NearestNeighbors const* algorithm, unsigned int neighbors = 3)
74  : m_algorithm(algorithm)
75  , m_classes(numberOfClasses(algorithm->dataset()))
78  { }
79 
80  /// \brief Constructor
81  ///
82  /// \param algorithm the used algorithm for nearest neighbor search
83  /// \param numClasses number of classes (given explicitly, not derived from the training data)
84  /// \param neighbors number of neighbors
85  SoftNearestNeighborClassifier(NearestNeighbors const* algorithm, std::size_t numClasses, unsigned int neighbors)
86  : m_algorithm(algorithm)
87  , m_classes(numClasses)
88  , m_neighbors(neighbors)
90  { }
91 
92  /// \brief From INameable: return the class name.
93  std::string name() const
94  { return "SoftNearestNeighborClassifier"; }
95 
96 
97  /// return the number of neighbors
98  unsigned int neighbors() const{
99  return m_neighbors;
100  }
101 
102  /// set the number of neighbors
103  void setNeighbors(unsigned int neighbors){
105  }
106 
107  /// query the way distances enter as weights
109  { return m_distanceWeights; }
110 
111  /// set the way distances enter as weights
113  { m_distanceWeights = dw; }
114 
115  /// get internal parameters of the model
116  virtual RealVector parameterVector() const{
117  RealVector parameters(1);
118  parameters(0) = m_neighbors;
119  return parameters;
120  }
121 
122  /// set internal parameters of the model
123  virtual void setParameterVector(RealVector const& newParameters){
124  SHARK_CHECK(newParameters.size() == 1,
125  "[SoftNearestNeighborClassifier::setParameterVector] invalid number of parameters");
126  //~ SHARK_CHECK((unsigned int)newParameters(0) == newParameters(0) && newParameters(0) >= 1.0,
127  //~ "[SoftNearestNeighborClassifier::setParameterVector] invalid number of neighbors");
128  m_neighbors = (unsigned int)newParameters(0);
129  }
130 
131  /// return the size of the parameter vector
132  virtual std::size_t numberOfParameters() const{
133  return 1;
134  }
135 
136  boost::shared_ptr<State> createState()const{
137  return boost::shared_ptr<State>(new EmptyState());
138  }
139 
140  /// soft k-nearest-neighbor prediction
141  void eval(BatchInputType const& patterns, BatchOutputType& outputs) const {
142  std::size_t numPatterns = shark::size(patterns);
143  std::vector<typename NearestNeighbors::DistancePair> neighbors = m_algorithm->getNeighbors(patterns, m_neighbors);
144 
145  outputs.resize(numPatterns, m_classes);
146  outputs.clear();
147 
148  for(std::size_t p = 0; p != numPatterns;++p)
149  {
150  double wsum = 0.0;
151  for ( std::size_t k = 0; k != m_neighbors; ++k)
152  {
153  double w;
154  if (m_distanceWeights == UNIFORM) w = 1.0;
155  else
156  {
157  double d = neighbors[p*m_neighbors+k].key;
158  if (d < 1e-100) w = 1e100;
159  else w = 1.0 / d;
160  }
161 
162  outputs(p, neighbors[p*m_neighbors+k].value) += w;
163  wsum += w;
164  }
165  row(outputs, p) *= (1.0 / wsum);
166  }
167  }
168  void eval(BatchInputType const& patterns, BatchOutputType& outputs, State & state)const{
169  eval(patterns, outputs);
170  }
171 
172  using base_type::eval;
173 
174  /// from ISerializable, reads a model from an archive
175  void read(InArchive& archive){
176  archive & m_neighbors;
177  archive & m_classes;
178  }
179 
180  /// from ISerializable, writes a model to an archive
181  void write(OutArchive& archive) const{
182  archive & m_neighbors;
183  archive & m_classes;
184  }
185 
186 protected:
187  NearestNeighbors const* m_algorithm;
188 
189  /// number of classes
190  std::size_t m_classes;
191 
192  /// number of neighbors to be taken into account
193  unsigned int m_neighbors;
194 
195  /// type of distance-based weights computation
197 };
198 
199 
200 }
201 #endif