RBFLayer.h
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief Implements a radial basis function layer.
5  *
6  *
7  *
8  * \author O. Krause
9  * \date 2014
10  *
11  *
12  * \par Copyright 1995-2015 Shark Development Team
13  *
14  * <BR><HR>
15  * This file is part of Shark.
16  * <http://image.diku.dk/shark/>
17  *
18  * Shark is free software: you can redistribute it and/or modify
19  * it under the terms of the GNU Lesser General Public License as published
20  * by the Free Software Foundation, either version 3 of the License, or
21  * (at your option) any later version.
22  *
23  * Shark is distributed in the hope that it will be useful,
24  * but WITHOUT ANY WARRANTY; without even the implied warranty of
25  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26  * GNU Lesser General Public License for more details.
27  *
28  * You should have received a copy of the GNU Lesser General Public License
29  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
30  *
31  */
32 #ifndef SHARK_MODELS_RBFLayer_H
33 #define SHARK_MODELS_RBFLayer_H
34 
35 #include <shark/Core/DLLSupport.h>
37 #include <boost/math/constants/constants.hpp>
38 namespace shark {
39 
40 /// \brief Implements a layer of radial basis functions in a neural network.
41 ///
42 /// A Radial basis function layer as modeled in shark is a set of N
43 /// Gaussian distributions \f$ p(x|i) \f$.
44 /// \f[
45 /// p(x|i) = e^{-\gamma_i*\|x-m_i\|^2}
46 /// \f]
47 /// and the layer transforms an input x to a vector \f$(p(x|1),\dots,p(x|N))\f$.
48 /// The \f$\gamma_i\f$ govern the width of the Gaussians, while the
49 /// vectors \f$ m_i \f$ set the centers of every Gaussian distribution.
50 ///
51 /// RBF networks profit much from good guesses on the centers and
52 /// kernel function parameters. In case of a Gaussian kernel a call
53 /// to k-Means or the EM-algorithm can be used to get a good
54 /// initialisation for the network.
55 class RBFLayer : public AbstractModel<RealVector,RealVector>
56 {
57 private:
    /// \brief Per-evaluation scratch storage handed out by createState().
    ///
    /// Holds two matrices that resize() always dimensions as
    /// numPatterns x numNeurons.
    // NOTE(review): the member names suggest norm2 stores squared distances to
    // the centers and p the resulting activations; how they are filled is not
    // visible in this header -- confirm against the implementation file.
58  struct InternalState: public State{
59  RealMatrix norm2;
60  RealMatrix p;
61 
    /// \brief Resizes both matrices to numPatterns rows and numNeurons columns.
62  void resize(std::size_t numPatterns, std::size_t numNeurons){
63  norm2.resize(numPatterns,numNeurons);
64  p.resize(numPatterns,numNeurons);
65  }
66  };
67 
68 public:
69  /// \brief Creates an empty Radial Basis Function layer.
    // NOTE(review): no declaration follows the comment above in this listing
    // (listing line 70 is absent); presumably the default constructor --
    // confirm against the original header.
71 
72  /// \brief Creates a layer of a Radial Basis Function Network.
73  ///
74  /// This constructor creates a Radial Basis Function layer with
75  /// \em numInput input neurons and \em numOutput output neurons.
76  ///
77  /// \param numInput Number of input neurons, equal to the dimensionality of
78  /// the input space.
79  /// \param numOutput Number of output neurons, equal to the dimensionality of
80  /// the output space and to the number of Gaussian distributions.
81  SHARK_EXPORT_SYMBOL RBFLayer(std::size_t numInput, std::size_t numOutput);
82 
83  /// \brief From INameable: returns the class name, the literal "RBFLayer".
84  std::string name() const
85  { return "RBFLayer"; }
86 
87  ///\brief Returns the current parameter vector. The number and order of entries depend on the training parameters.
88  ///
89  ///The format of the parameter vector is \f$ (m_1,\dots,m_k,\log(\gamma_1),\dots,\log(\gamma_k))\f$.
90  ///If training of one or more parameter groups is deactivated, they are removed from the parameter vector.
91  SHARK_EXPORT_SYMBOL RealVector parameterVector()const;
92 
93  ///\brief Sets the internal parameters from \em newParameters.
94  SHARK_EXPORT_SYMBOL void setParameterVector(RealVector const& newParameters);
95 
96  ///\brief Returns the number of parameters which are currently enabled for training.
97  SHARK_EXPORT_SYMBOL std::size_t numberOfParameters()const;
98 
99  ///\brief Returns the number of input neurons, i.e. the number of columns of the center matrix.
100  std::size_t inputSize()const{
101  return m_centers.size2();
102  }
103 
104  ///\brief Returns the number of output neurons, i.e. the number of rows of the center matrix.
105  std::size_t outputSize()const{
106  return m_centers.size1();
107  }
108 
    ///\brief Creates a new, default-constructed InternalState wrapped in a shared pointer to State.
109  boost::shared_ptr<State> createState()const{
110  return boost::shared_ptr<State>(new InternalState());
111  }
112 
113 
114  /// \brief Configures the structure of the Radial Basis Function layer.
115  ///
116  /// This method initializes the structure of the layer with
117  /// \em numInput input neurons and \em numOutput output neurons
118  /// (one output neuron per basis function).
119  ///
120  /// \param numInput Number of input neurons, equal to the dimensionality of
121  /// the input space.
122  /// \param numOutput Number of output neurons (basis functions), equal to the dimensionality of
123  /// the output space.
124  SHARK_EXPORT_SYMBOL void setStructure(std::size_t numInput, std::size_t numOutput);
125 
126 
    ///\brief Evaluates the layer for a batch of input patterns and writes the result to \em outputs.
    // NOTE(review): listing line 127 is absent before this declaration.
    // Presumably \em state receives the InternalState intermediates; that
    // happens in the implementation file and cannot be verified here.
128  SHARK_EXPORT_SYMBOL void eval(BatchInputType const& patterns, BatchOutputType& outputs, State& state)const;
129 
130 
    // NOTE(review): the line carrying the return type and function name of the
    // declaration below (listing line 131) is missing from this listing; only
    // its parameter list survives. The parameters (patterns, coefficients, the
    // state from an evaluation, and an output gradient) look like a weighted
    // derivative with respect to the parameters -- confirm against the
    // original header before relying on this.
132  BatchInputType const& pattern, BatchOutputType const& coefficients, State const& state, RealVector& gradient
133  )const;
134 
135  ///\brief Enables or disables parameters for learning.
136  ///
137  /// \param centers whether the centers should be trained
138  /// \param width whether the distribution width should be trained
139  SHARK_EXPORT_SYMBOL void setTrainingParameters(bool centers, bool width);
140 
141  ///\brief Returns the center values of the neurons (read-only).
142  BatchInputType const& centers()const{
143  return m_centers;
144  }
145  ///\brief Returns the center values of the neurons.
    // NOTE(review): the signature line of this accessor (listing line 146) is
    // missing; only the body, which returns m_centers, is visible. Presumably
    // this is the non-const overload of centers() through which callers can
    // modify the centers -- confirm against the original header.
147  return m_centers;
148  }
149 
150  ///\brief Returns the width parameters of the Gaussian functions.
151  RealVector const& gamma()const{
152  return m_gamma;
153  }
154 
155  /// \brief Sets the width parameters - the gamma values - of the distributions.
156  SHARK_EXPORT_SYMBOL void setGamma(RealVector const& gamma);
157 
158  /// From ISerializable: reads a model from an archive.
159  SHARK_EXPORT_SYMBOL void read( InArchive & archive );
160 
161  /// From ISerializable: writes a model to an archive.
162  SHARK_EXPORT_SYMBOL void write( OutArchive & archive ) const;
163 protected:
164  //====model parameters
165 
166  ///\brief The center points. Row i holds the center of neuron number i; the number of columns equals the input dimension.
167  RealMatrix m_centers;
168 
169  ///\brief Stores the width parameters of the Gaussian functions.
170  RealVector m_gamma;
171 
172  /// \brief The logarithm of the normalization constant for every distribution.
173  RealVector m_logNormalization;
174 
175  //=====training parameters
176  ///enables learning of the center points of the neurons
178  ///enables learning of the width parameters.
    // NOTE(review): the member declarations that the two comments above
    // describe (listing lines 177 and 179) are missing from this listing;
    // presumably two boolean flags matching setTrainingParameters(centers, width)
    // -- confirm against the original header.
180 
181 
182 
183 };
184 }
185 
186 #endif
187