ConvexCombination.h
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief Implements a model computing a convex combination of its inputs.
5  *
6  *
7  *
8  * \author T. Glasmachers, O. Krause
9  * \date 2010-2011
10  *
11  *
12  * \par Copyright 1995-2015 Shark Development Team
13  *
14  * <BR><HR>
15  * This file is part of Shark.
16  * <http://image.diku.dk/shark/>
17  *
18  * Shark is free software: you can redistribute it and/or modify
19  * it under the terms of the GNU Lesser General Public License as published
20  * by the Free Software Foundation, either version 3 of the License, or
21  * (at your option) any later version.
22  *
23  * Shark is distributed in the hope that it will be useful,
24  * but WITHOUT ANY WARRANTY; without even the implied warranty of
25  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26  * GNU Lesser General Public License for more details.
27  *
28  * You should have received a copy of the GNU Lesser General Public License
29  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
30  *
31  */
32 #ifndef SHARK_MODELS_ConvexCombination_H
33 #define SHARK_MODELS_ConvexCombination_H
34 
36 namespace shark {
37 
38 
39 ///
40 /// \brief Models a convex combination of inputs
41 ///
42 /// For a given input vector x, the convex combination returns \f$ f_i(x) = \sum_j w_{ij} x_j \f$,
43 /// where \f$ w_{ij} > 0 \f$ and \f$ \sum_j w_{ij} = 1 \f$, that is, the outputs of
44 /// the model are a convex combination of the inputs.
45 ///
46 /// To ensure that the constraints are fulfilled, the model uses a different
47 /// set of weights q_i and \f$ w_{ij} = \exp(q_{ij}) / \sum_k \exp(q_{ik}) \f$. As usual, this
48 /// encoding is only used for the derivatives and the parameter vectors, not
49 /// when the weights are explicitly set. In the latter case, the user must provide
50 /// a set of suitable \f$ w_{ij} \f$.
51 class ConvexCombination : public AbstractModel<RealVector,RealVector>
52 {
53 private:
54  RealMatrix m_w; ///< the convex comination weights. it holds sum(row(w_i)) = 1
55 public:
56 
57  /// CDefault Constructor; use setStructure later
61  }
62 
63  /// Constructor creating a model with given dimnsionalities and optional offset term.
64  ConvexCombination(std::size_t inputs, std::size_t outputs = 1)
65  : m_w(outputs,inputs,0.0){
68  }
69 
70  /// Construction from matrix
71  ConvexCombination(RealMatrix const& matrix):m_w(matrix){
74  }
75 
76  /// \brief From INameable: return the class name.
77  std::string name() const
78  { return "ConvexCombination"; }
79 
80  ///swap
81  friend void swap(ConvexCombination& model1,ConvexCombination& model2){
82  swap(model1.m_w,model2.m_w);
83  }
84 
85  ///operator =
87  ConvexCombination tempModel(model);
88  swap(*this,tempModel);
89  return *this;
90  }
91 
92  /// obtain the input dimension
93  std::size_t inputSize() const{
94  return m_w.size2();
95  }
96 
97  /// obtain the output dimension
98  std::size_t outputSize() const{
99  return m_w.size1();
100  }
101 
102  /// obtain the parameter vector
103  RealVector parameterVector() const{
104  RealVector ret(numberOfParameters());
105  init(ret) << toVector(log(m_w));
106  return ret;
107  }
108 
109  /// overwrite the parameter vector
110  void setParameterVector(RealVector const& newParameters)
111  {
112  init(newParameters) >> toVector(m_w);
113  noalias(m_w) = exp(m_w);
114  for(std::size_t i = 0; i != outputSize(); ++i){
115  row(m_w,i) /= sum(row(m_w,i));
116  }
117  }
118 
119  /// return the number of parameter
120  std::size_t numberOfParameters() const{
121  return m_w.size1()*m_w.size2();
122  }
123 
124  /// overwrite structure and parameters
125  void setStructure(std::size_t inputs, std::size_t outputs = 1){
126  ConvexCombination model(inputs,outputs);
127  swap(*this,model);
128  }
129 
130  RealMatrix const& weights() const{
131  return m_w;
132  }
133 
134  RealMatrix& weights(){
135  return m_w;
136  }
137 
138  boost::shared_ptr<State> createState()const{
139  return boost::shared_ptr<State>(new EmptyState());
140  }
141 
142  /// Evaluate the model: output = w * input
143  void eval(BatchInputType const& inputs, BatchOutputType& outputs)const{
144  outputs.resize(inputs.size1(),m_w.size1());
145  noalias(outputs) = prod(inputs,trans(m_w));
146  }
147  /// Evaluate the model: output = w *input
148  void eval(BatchInputType const& inputs, BatchOutputType& outputs, State& state)const{
149  eval(inputs,outputs);
150  }
151 
152  ///\brief Calculates the first derivative w.r.t the parameters and summing them up over all patterns of the last computed batch
154  BatchInputType const& patterns, RealMatrix const& coefficients, State const& state, RealVector& gradient
155  )const{
156  SIZE_CHECK(coefficients.size2()==outputSize());
157  SIZE_CHECK(coefficients.size1()==patterns.size1());
158 
159  gradient.resize(numberOfParameters());
160  blas::dense_matrix_adaptor<double> weightGradient = blas::adapt_matrix(outputSize(),inputSize(),gradient.storage());
161 
162  //derivative is
163  //sum_i sum_j c_ij sum_k x_ik grad_q w_jk= sum_k sum_j grad_q w_jk (sum_i c_ij x_ik)
164  //and we set d_jk=sum_i c_ij x_ik => d = C^TX
165  RealMatrix d = prod(trans(coefficients), patterns);
166 
167  //use the same drivative as in the softmax model
168  for(std::size_t i = 0; i != outputSize(); ++i){
169  double mass=inner_prod(row(d,i),row(m_w,i));
170  noalias(row(weightGradient,i)) = element_prod(
171  row(d,i) - mass,
172  row(m_w,i)
173  );
174  }
175  }
176  ///\brief Calculates the first derivative w.r.t the inputs and summs them up over all patterns of the last computed batch
178  BatchInputType const & patterns,
179  BatchOutputType const & coefficients,
180  State const& state,
181  BatchInputType& derivative
182  )const{
183  SIZE_CHECK(coefficients.size2() == outputSize());
184  SIZE_CHECK(coefficients.size1() == patterns.size1());
185 
186  derivative.resize(patterns.size1(),inputSize());
187  noalias(derivative) = prod(coefficients,m_w);
188  }
189 
190  /// From ISerializable
191  void read(InArchive& archive){
192  archive >> m_w;
193  }
194  /// From ISerializable
195  void write(OutArchive& archive) const{
196  archive << m_w;
197  }
198 };
199 
200 
201 }
202 #endif