// NOTE(review): the leading numbers on these lines (35, 36, 44, ...) look like
// line numbers left over from a rendered source listing -- confirm and strip.
/// \brief Kernel adapter that forwards calls to a wrapped kernel.
///
/// The batch methods visible in this class pass columns(batch, m_start, m_end)
/// of both input batches to the wrapped kernel instead of the full batches.
35 #ifndef SHARK_MODELS_KERNELS_SUBRANGE_KERNEL_H 36 #define SHARK_MODELS_KERNELS_SUBRANGE_KERNEL_H 44 template<
class InputType>
45 class SubrangeKernelWrapper :
public AbstractKernelFunction<InputType>{
// Shorthand for the base class.
47 typedef AbstractKernelFunction<InputType> base_type;
/// \brief Stores the wrapped kernel and the two column bounds.
///
/// \param kernel kernel that receives the forwarded calls; it is dereferenced
///        here, so it must not be null
/// \param start  first bound handed to columns() by the batch methods
/// \param end    second bound handed to columns() by the batch methods
53 SubrangeKernelWrapper(AbstractKernelFunction<InputType>* kernel,std::size_t start, std::size_t end)
54 :m_kernel(kernel),m_start(start),m_end(end){
// Queries the wrapped kernel's derivative capabilities; the statements
// guarded by these two conditions are not part of this view.
55 if(kernel->hasFirstParameterDerivative())
57 if(kernel->hasFirstInputDerivative())
/// \brief Returns the fixed class name "SubrangeKernelWrapper".
62 std::string
name()
const 63 {
return "SubrangeKernelWrapper"; }
// Returns the wrapped kernel's parameter vector as-is (the enclosing
// function header is not part of this view).
66 return m_kernel->parameterVector();
// Hands newParameters straight to the wrapped kernel (the enclosing
// function header is not part of this view).
70 m_kernel->setParameterVector(newParameters);
// The parameter count reported is the wrapped kernel's (the enclosing
// function header is not part of this view).
74 return m_kernel->numberOfParameters();
// State objects are created by the wrapped kernel (the enclosing function
// header is not part of this view).
79 return m_kernel->createState();
/// \brief Overload of eval taking a single pair of inputs and returning a
/// double; its body is not part of this view.
82 double eval(ConstInputReference x1, ConstInputReference x2)
const{
/// \brief Batch overload of eval that also receives a state object.
///
/// Forwards to the wrapped kernel's eval, passing
/// columns(batch, m_start, m_end) of each input batch; result and state are
/// handed through as given.
86 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result, State& state)
const{
87 m_kernel->eval(
columns(batchX1,m_start,m_end),
columns(batchX2,m_start,m_end),result,state);
/// \brief Batch overload of eval without a state object.
///
/// Forwards to the wrapped kernel's eval, passing
/// columns(batch, m_start, m_end) of each input batch; result is handed
/// through as given.
90 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result)
const{
91 m_kernel->eval(
columns(batchX1,m_start,m_end),
columns(batchX2,m_start,m_end),result);
// Parameter list and call of the weighted parameter derivative (only partly
// shown in this view): the wrapped kernel's weightedParameterDerivative
// receives columns(batch, m_start, m_end) of both input batches.
95 ConstBatchInputReference batchX1,
96 ConstBatchInputReference batchX2,
97 RealMatrix
const& coefficients,
101 m_kernel->weightedParameterDerivative(
102 columns(batchX1,m_start,m_end),
103 columns(batchX2,m_start,m_end),
// Parameter list and body of the weighted input derivative (only partly
// shown in this view).
110 ConstBatchInputReference batchX1,
111 ConstBatchInputReference batchX2,
112 RealMatrix
const& coefficientsX2,
114 BatchInputType& gradient
// Scratch buffer for the wrapped kernel's output: as many rows as gradient
// currently has, and m_end-m_start columns.
// NOTE(review): gradient.size1() is read here, before the ensure_size call
// further down touches gradient -- confirm callers pass a pre-sized gradient.
116 BatchInputType temp(gradient.size1(),m_end-m_start);
// The wrapped kernel receives columns(batch, m_start, m_end) of both batches.
117 m_kernel->weightedInputDerivative(
118 columns(batchX1,m_start,m_end),
119 columns(batchX2,m_start,m_end),
// Requests a batchX1.size1() x batchX2.size2() shape for gradient
// (ensure_size is defined elsewhere).
// NOTE(review): the row count comes from batchX1 but the column count from
// batchX2 -- presumably both batches have the same number of columns; confirm.
124 ensure_size(gradient,batchX1.size1(),batchX2.size2());
// Wrapped kernel that receives the forwarded calls. Held as a raw pointer;
// who owns the pointee is not visible in this view.
136 AbstractKernelFunction<InputType>* m_kernel;
/// \brief Helper that stores one SubrangeKernelWrapper per kernel/range pair
/// and can expose the stored wrappers as a vector of base-class pointers.
141 template<
class InputType>
142 class SubrangeKernelBase
/// \brief Builds one wrapper per entry of \a kernels.
///
/// \param kernels container providing size(); entry i is read via
///        get(kernels,i) and becomes the wrapped kernel
/// \param ranges  entry i, read via get(ranges,i), supplies .first and
///        .second as the wrapper's start and end
146 template<
class Kernels,
class Ranges>
147 SubrangeKernelBase(Kernels
const& kernels, Ranges
const& ranges){
// NOTE(review): ranges is indexed with every i below kernels.size(); no check
// that both containers have the same length is visible in this view.
149 for(std::size_t i = 0; i != kernels.size(); ++i){
150 m_kernelWrappers.push_back(
151 SubrangeKernelWrapper<InputType>(
get(kernels,i),
get(ranges,i).first,
get(ranges,i).second)
/// \brief Fills a vector with the addresses of the stored wrappers, typed as
/// base-class pointers (the return statement is not part of this view).
///
/// The pointers refer to elements of m_kernelWrappers, so they are only
/// usable while that vector is neither reallocated nor destroyed.
156 std::vector<AbstractKernelFunction<InputType>* > makeKernelVector(){
157 std::vector<AbstractKernelFunction<InputType>* > kernels(m_kernelWrappers.size());
158 for(std::size_t i = 0; i != m_kernelWrappers.size(); ++i)
159 kernels[i] = & m_kernelWrappers[i];
// Wrappers stored by value; makeKernelVector() takes the addresses of these
// elements.
163 std::vector<SubrangeKernelWrapper <InputType> > m_kernelWrappers;
// Class head (the line naming the class is not part of this view): the class
// inherits privately from detail::SubrangeKernelBase<InputType>.
188 template<
class InputType>
190 :
private detail::SubrangeKernelBase<InputType>
// Shorthand for the private base listed above.
194 typedef detail::SubrangeKernelBase<InputType> base_type1;
// Returns the fixed string "SubrangeKernel" (the function header is not part
// of this view).
200 {
return "SubrangeKernel"; }
// Constructor template (its signature line is not part of this view).
202 template<
class Kernels,
class Ranges>
// kernels and ranges are passed to the base_type1 constructor.
204 : base_type1(kernels,ranges)
// base_type2 is constructed from the pointer vector produced by
// base_type1::makeKernelVector().
// NOTE(review): this relies on base_type1 being initialized before base_type2;
// confirm base_type1 precedes base_type2 in the base-class list.
205 , base_type2(base_type1::makeKernelVector()){}