ProductKernel.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Product of kernel functions.
6  *
7  *
8  *
9  * \author T. Glasmachers, O.Krause
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_KERNELS_PRODUCTKERNEL_H
36 #define SHARK_MODELS_KERNELS_PRODUCTKERNEL_H
37 
38 
40 
41 namespace shark{
42 
43 
44 ///
45 /// \brief Product of kernel functions.
46 ///
47 /// \par
48 /// The product of any number of kernels is again a valid kernel.
49 /// This class supports a kernel of the form
50 /// \f$ k(x, x') = k_1(x, x') \cdot k_2(x, x') \cdot \dots \cdot k_n(x, x') \f$
51 /// for any number of base kernels. All kernels need to be defined
52 /// on the same input space.
53 ///
54 /// \par
55 /// Derivatives are currently not implemented. Only the plain
56 /// kernel value can be computed. Everyone is free to add this
57 /// functionality :)
58 ///
59 template<class InputType>
60 class ProductKernel : public AbstractKernelFunction<InputType>
61 {
62 private:
64 public:
69  /// \brief Default constructor.
71  // this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
72  // this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
73  // this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
74  // this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
75  this->m_features |= base_type::IS_NORMALIZED; // an "empty" product is a normalized kernel (k(x, x) = 1).
76  }
77 
78  /// \brief Constructor for a product of two kernels.
79  ProductKernel(SubKernel* k1, SubKernel* k2){
80  // this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
81  // this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
82  // this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
83  // this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
84  this->m_features |= base_type::IS_NORMALIZED; // an "empty" product is a normalized kernel (k(x, x) = 1).
85  addKernel(k1);
86  addKernel(k2);
87  }
88  ProductKernel(std::vector<SubKernel*> kernels){
89  // this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
90  // this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
91  // this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
92  // this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
93  this->m_features |= base_type::IS_NORMALIZED; // an "empty" product is a normalized kernel (k(x, x) = 1).
94  for(std::size_t i = 0; i != kernels.size(); ++i)
95  addKernel(kernels[i]);
96  }
97 
98  /// \brief From INameable: return the class name.
99  std::string name() const
100  { return "ProductKernel"; }
101 
102  /// \brief Add one more kernel to the expansion.
103  ///
104  /// \param k The pointer is expected to remain valid during the lifetime of the ProductKernel object.
105  ///
106  void addKernel(SubKernel* k){
107  SHARK_ASSERT(k != NULL);
108 
109  m_kernels.push_back(k);
111  if (! k->isNormalized()) this->m_features.reset(base_type::IS_NORMALIZED); // products of normalized kernels are normalized.
112  }
113 
114  RealVector parameterVector() const{
115  RealVector ret(m_numberOfParameters);
116  init(ret) << parameterSet(m_kernels);
117  return ret;
118  }
119 
120  void setParameterVector(RealVector const& newParameters){
121  SIZE_CHECK(newParameters.size() == m_numberOfParameters);
122  init(newParameters) >> parameterSet(m_kernels);
123  }
124 
125  std::size_t numberOfParameters() const{
126  return m_numberOfParameters;
127  }
128 
129  /// \brief evaluates the kernel function
130  ///
131  /// This function returns the product of all sub-kernels.
132  double eval(ConstInputReference x1, ConstInputReference x2) const{
133  double prod = 1.0;
134  for (std::size_t i=0; i<m_kernels.size(); i++)
135  prod *= m_kernels[i]->eval(x1, x2);
136  return prod;
137  }
138 
139  void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result) const{
140  std::size_t sizeX1=shark::size(batchX1);
141  std::size_t sizeX2=shark::size(batchX2);
142 
143  //evaluate first kernel to initialize the result
144  m_kernels[0]->eval(batchX1,batchX2,result);
145 
146  RealMatrix kernelResult(sizeX1,sizeX2);
147  for(std::size_t i = 1; i != m_kernels.size(); ++i){
148  m_kernels[i]->eval(batchX1,batchX2,kernelResult);
149  noalias(result) *= kernelResult;
150  }
151  }
152 
153  void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result, State& state) const{
154  eval(batchX1,batchX2,result);
155  }
156 
157  /// From ISerializable.
158  void read(InArchive& ar){
159  for(std::size_t i = 0;i != m_kernels.size(); ++i ){
160  ar >> *m_kernels[i];
161  }
162  ar >> m_numberOfParameters;
163  }
164 
165  /// From ISerializable.
166  void write(OutArchive& ar) const{
167  for(std::size_t i = 0;i != m_kernels.size(); ++i ){
168  ar << const_cast<AbstractKernelFunction<InputType> const&>(*m_kernels[i]);//prevent serialization warning
169  }
170  ar << m_numberOfParameters;
171  }
172 
173 protected:
174  std::vector<SubKernel*> m_kernels; ///< vector of sub-kernels
175  std::size_t m_numberOfParameters; ///< total number of parameters in the product (this is redundant information)
176 };
177 
178 
179 }
180 #endif