35 #ifndef SHARK_MODELS_KERNELS_WEIGHTED_SUM_KERNEL_H 36 #define SHARK_MODELS_KERNELS_WEIGHTED_SUM_KERNEL_H 41 #include <boost/utility/enable_if.hpp> 62 template<
class InputType=RealVector>
68 struct InternalState:
public State{
70 std::vector<RealMatrix> kernelResults;
71 std::vector<boost::shared_ptr<State> > kernelStates;
73 InternalState(std::size_t numSubKernels)
74 :kernelResults(numSubKernels),kernelStates(numSubKernels){}
76 void resize(std::size_t sizeX1, std::size_t sizeX2){
77 result.resize(sizeX1, sizeX2);
78 for(std::size_t i = 0; i != kernelResults.size(); ++i){
79 kernelResults[i].resize(sizeX1, sizeX2);
// --- constructor fragment (garbled extraction; several original lines missing) ---
// Rejects an empty sub-kernel list up front.
89 SHARK_CHECK( base.size() > 0,
"[WeightedSumKernel::WeightedSumKernel] There should be at least one sub-kernel.");
// One bookkeeping entry (kernel pointer, weight, adaptive flag) per sub-kernel.
91 m_base.resize( base.size() );
94 for (std::size_t i=0; i !=
m_base.size() ; i++) {
96 m_base[i].kernel = base[i];
// Sub-kernel parameters are NOT optimized by default; see setAdaptive().
98 m_base[i].adaptive =
false;
// The combined kernel offers a parameter derivative only if EVERY sub-kernel does.
104 for (
unsigned int i=0; i<
m_base.size(); i++ ){
105 if ( !
m_base[i].kernel->hasFirstParameterDerivative() ) {
106 hasFirstParameterDerivative =
false;
// Likewise for the input derivative: all sub-kernels must support it.
111 for (
unsigned int i=0; i<
m_base.size(); i++ ){
112 if ( !
m_base[i].kernel->hasFirstInputDerivative() ) {
113 hasFirstInputDerivative =
false;
// NOTE(review): the bodies of these ifs (presumably setting the corresponding
// feature flags on the base class) are missing from this extraction — confirm upstream.
118 if ( hasFirstParameterDerivative )
121 if ( hasFirstInputDerivative )
// --- small accessor fragments (signatures lost in extraction) ---
// Kernel identifier used by Shark's name() convention.
127 {
return "WeightedSumKernel"; }
// Whether sub-kernel `index` contributes its own parameters to this kernel.
131 return m_base[index].adaptive;
// Enables/disables optimization of sub-kernel `index`'s parameters.
135 m_base[index].adaptive = b;
// Applies the same adaptive flag to every sub-kernel.
140 for (std::size_t i=0; i!=
m_base.size(); i++)
// Current (unnormalized) weight of sub-kernel `index`.
148 return m_base[index].weight;
// --- parameterVector() fragment ---
// Layout: first m_base.size()-1 entries encode the weights (log-encoded,
// see setParameterVector below), then the parameters of each adaptive sub-kernel.
157 std::size_t index = 0;
158 for (; index !=
m_base.size()-1; index++){
// Append each adaptive sub-kernel's own parameter vector.
162 for (std::size_t i=0; i !=
m_base.size(); i++){
164 std::size_t n =
m_base[i].kernel->numberOfParameters();
165 subrange(ret,index,index+n) =
m_base[i].kernel->parameterVector();
174 InternalState* state =
new InternalState(
m_base.size());
175 for(std::size_t i = 0; i !=
m_base.size(); ++i){
176 state->kernelStates[i]=
m_base[i].kernel->createState();
178 return boost::shared_ptr<State>(state);
// --- setParameterVector() fragment ---
// The first m_base.size()-1 parameters are log-weights of kernels 1..n-1
// (kernel 0's weight is the fixed reference); exp() keeps weights positive
// under unconstrained optimization. m_weightsum accumulates the normalizer.
187 std::size_t index = 0;
188 for (; index !=
m_base.size()-1; index++){
189 double w = newParameters(index);
190 m_base[index+1].weight = std::exp(w);
191 m_weightsum +=
m_base[index+1].weight;
// Remaining entries are forwarded to the adaptive sub-kernels, in order.
194 for (std::size_t i=0; i !=
m_base.size(); i++){
196 std::size_t n =
m_base[i].kernel->numberOfParameters();
197 m_base[i].kernel->setParameterVector(
subrange(newParameters,index,index+n));
209 double eval(ConstInputReference x1, ConstInputReference x2)
const{
210 double numerator = 0.0;
211 for (std::size_t i=0; i !=
m_base.size(); i++){
212 double result =
m_base[i].kernel->eval(x1, x2);
213 numerator +=
m_base[i].weight*result;
221 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result)
const{
224 ensure_size(result,sizeX1,sizeX2);
227 RealMatrix kernelResult(sizeX1,sizeX2);
228 for (std::size_t i = 0; i !=
m_base.size(); i++){
229 m_base[i].kernel->eval(batchX1, batchX2,kernelResult);
230 result +=
m_base[i].weight*kernelResult;
239 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result,
State& state)
const{
242 ensure_size(result,sizeX1,sizeX2);
245 InternalState& s = state.
toState<InternalState>();
246 s.resize(sizeX1,sizeX2);
248 for (std::size_t i=0; i !=
m_base.size(); i++){
249 m_base[i].kernel->eval(batchX1,batchX2,s.kernelResults[i],*s.kernelStates[i]);
250 result +=
m_base[i].weight*s.kernelResults[i];
// --- weightedParameterDerivative fragment (signature start and tail missing) ---
258 ConstBatchInputReference batchX1,
259 ConstBatchInputReference batchX2,
260 RealMatrix
const& coefficients,
266 std::size_t numKernels =
m_base.size();
268 InternalState
const& s = state.
toState<InternalState>();
// numeratorSum = <coefficients, sum_i w_i K_i>; s.result holds the unnormalized sum.
276 double numeratorSum =
sum(coefficients * s.result);
// Gradient wrt the log-weight of kernel i (i starts at 1: kernel 0's weight is fixed):
// quotient rule for (w_i*K_i)/m_weightsum, with sumSquared = m_weightsum^2
// (its definition is missing from this extraction).
277 for (std::size_t i = 1; i != numKernels; i++) {
279 double summedK=
sum(coefficients * s.kernelResults[i]);
280 gradient(i-1) =
m_base[i].weight * (summedK *
m_weightsum - numeratorSum) / sumSquared;
// Remaining gradient entries: parameters of the adaptive sub-kernels.
283 std::size_t gradPos = numKernels-1;
284 RealVector kernelGrad;
285 for (std::size_t i=0; i != numKernels; i++) {
288 m_base[i].kernel->weightedParameterDerivative(batchX1,batchX2,coefficients,*s.kernelStates[i],kernelGrad);
289 std::size_t n = kernelGrad.size();
// --- weightedInputDerivative fragment ---
// Thin dispatcher: delegates to weightedInputDerivativeImpl, which is only
// implemented for RealMatrix batch types (see the enable_if/disable_if pair below).
303 ConstBatchInputReference batchX1,
304 ConstBatchInputReference batchX2,
305 RealMatrix
const& coefficientsX2,
307 BatchInputType& gradient
311 weightedInputDerivativeImpl<BatchInputType>(batchX1,batchX2,coefficientsX2,state,gradient);
// --- read()/write() serialization fragments ---
// read: deserialize each sub-kernel in order from the archive.
315 for(std::size_t i = 0;i !=
m_base.size(); ++i ){
318 ar >> *(
m_base[i].kernel);
// write: serialize each sub-kernel; const_cast works around sub-kernel
// serialization not being const-correct.
324 for(std::size_t i=0;i!=
m_base.size();++i){
327 ar << const_cast<AbstractKernelFunction<InputType>
const&>(*(
m_base[i].kernel));
// --- updateNumberOfParameters() fragment: recount over all sub-kernels ---
344 for (std::size_t i=0; i !=
m_base.size(); i++)
355 ConstBatchInputReference batchX1,
356 ConstBatchInputReference batchX2,
357 RealMatrix
const& coefficientsX2,
359 BatchInputType& gradient,
360 typename boost::enable_if<boost::is_same<T,RealMatrix > >::type* dummy = 0
362 std::size_t numKernels =
m_base.size();
363 InternalState
const& s = state.
toState<InternalState>();
367 m_base[0].kernel->weightedInputDerivative(batchX1, batchX2, coefficientsX2, *s.kernelStates[0], gradient);
369 BatchInputType kernelGrad;
370 for (std::size_t i=1; i != numKernels; i++){
371 m_base[i].kernel->weightedInputDerivative(batchX1, batchX2, coefficientsX2, *s.kernelStates[i], kernelGrad);
373 gradient += coeff * kernelGrad;
378 ConstBatchInputReference batchX1,
379 ConstBatchInputReference batchX2,
380 RealMatrix
const& coefficientsX2,
382 BatchInputType& gradient,
383 typename boost::disable_if<boost::is_same<T,RealMatrix > >::type* dummy = 0
385 throw SHARKEXCEPTION(
"[WeightedSumKernel::weightdInputDerivative] The used BatchInputType is no Vector");