35 #ifndef SHARK_MODELS_KERNELS_NORMALIZED_KERNEL_H 36 #define SHARK_MODELS_KERNELS_NORMALIZED_KERNEL_H 49 template<
class InputType=RealVector>
// Cached intermediate results shared between the state-based batch eval and
// the derivative methods (each retrieves it via state.toState<InternalState>()).
// NOTE(review): this excerpt is incomplete -- the declarations of kxy/kxx/kyy,
// the signature of the resize helper and the loop bodies are not visible here.
55 struct InternalState:
public State{
// Base-kernel state for the cross evaluation of batchX1 against batchX2
// (passed as *s.stateKxy when filling s.kxy in eval).
60 boost::shared_ptr<State> stateKxy;
// One base-kernel state per element x_i, used for the self-evaluations k(x_i, x_i).
61 std::vector<boost::shared_ptr<State> > stateKxx;
// One base-kernel state per element y_j, used for the self-evaluations k(y_j, y_j).
62 std::vector<boost::shared_ptr<State> > stateKyy;
// Resize the caches for a sizeX1 x sizeX2 pair of batches
// (eval calls this as s.resize(sizeX1, sizeX2, m_base)).
65 kxy.resize(sizeX1,sizeX2);
68 stateKxx.resize(sizeX1);
69 stateKyy.resize(sizeX2);
// Presumably each loop creates one base-kernel state per element; the loop
// bodies are not part of this excerpt -- TODO confirm.
71 for(std::size_t i = 0; i != sizeX1;++i){
74 for(std::size_t i = 0; i != sizeX2;++i){
// Display name of this kernel: the base kernel's name wrapped as
// "NormalizedKernel<...>".
// NOTE(review): the signature line of this member is not part of this excerpt.
95 {
return "NormalizedKernel<" +
m_base->
name() +
">"; }
// Allocates a fresh InternalState and hands ownership to a shared_ptr<State>.
// NOTE(review): presumably the body of the state factory; the signature and
// closing brace are not part of this excerpt.
111 InternalState* state =
new InternalState();
112 return boost::shared_ptr<State>(state);
// Evaluates the kernel on a single pair of inputs (x1, x2).
// NOTE(review): the body of this overload is not part of this excerpt.
119 double eval(ConstInputReference x1, ConstInputReference x2)
const{
// Batch evaluation that also records intermediate base-kernel results in
// `state` (as an InternalState), which the derivative methods read back later.
// NOTE(review): this excerpt is incomplete; several body lines, including the
// declaration of singleBatch, are not visible here.
131 void eval(ConstBatchInputReference
const& batchX1, ConstBatchInputReference
const& batchX2, RealMatrix& result,
State& state)
const{
// Reinterpret the generic state as this kernel's InternalState and size its caches.
132 InternalState& s = state.
toState<InternalState>();
134 std::size_t sizeX1 =
size(batchX1);
135 std::size_t sizeX2 =
size(batchX2);
136 s.resize(sizeX1,sizeX2,
m_base);
137 result.resize(sizeX1,sizeX2);
// Cross matrix of the base kernel between the two batches, cached in s.kxy.
139 m_base->
eval(batchX1, batchX2,s.kxy, *s.stateKxy);
// Self-evaluations k(x_i, x_i): each x_i is evaluated against itself as a
// one-element batch, with its own per-element base-kernel state.
148 RealMatrix singleResult(1,1);
149 for(std::size_t i = 0; i != sizeX1;++i){
150 get(singleBatch,0) =
get(batchX1,i);
151 m_base->
eval(singleBatch,singleBatch,singleResult,*s.stateKxx[i]);
152 s.kxx[i] = singleResult(0,0);
// Same for the self-evaluations k(y_j, y_j) of the second batch.
155 for(std::size_t j = 0; j != sizeX2;++j){
156 get(singleBatch,0) =
get(batchX2,j);
157 m_base->
eval(singleBatch,singleBatch,singleResult,*s.stateKyy[j]);
158 s.kyy[j] = singleResult(0,0);
// Square roots of the cached k(y_j, y_j) values (presumably element-wise).
160 RealVector sqrtKyy=sqrt(s.kyy);
// Stateless batch evaluation of the normalized kernel
//   k~(x, y) = k(x, y) / sqrt(k(x, x) * k(y, y)).
// NOTE(review): this excerpt is incomplete -- the remaining body lines are not
// visible here.
171 void eval(ConstBatchInputReference
const& batchX1, ConstBatchInputReference
const& batchX2, RealMatrix& result)
const{
172 std::size_t sizeX1 =
size(batchX1);
173 std::size_t sizeX2 =
size(batchX2);
// sqrt(k(y_j, y_j)) for every element y_j of the second batch.
177 RealVector sqrtKyy(sizeX2);
178 for(std::size_t j = 0; j != sizeX2;++j){
179 sqrtKyy(j) = std::sqrt(
m_base->
eval(
get(batchX2,j),
get(batchX2,j)));
// sqrt(k(x_i, x_i)) for element x_i of the FIRST batch. The loop runs over
// sizeX1, so this must index batchX1 (as the stateful eval does for s.kxx).
// Reading batchX2 here used the wrong points and indexed past the end of
// batchX2 whenever sizeX1 > sizeX2.
182 for(std::size_t i = 0; i != sizeX1;++i){
183 double sqrtKxx = std::sqrt(
m_base->
eval(
get(batchX1,i),
get(batchX1,i)));
// NOTE(review): the beginning of this member (its name and return type), its
// trailing parameters and many body lines are not part of this excerpt.
// From the visible code it accumulates a coefficient-weighted derivative into
// `gradient`, presumably with respect to the base kernel's parameters -- TODO
// confirm against the full source.
194 ConstBatchInputReference
const& batchX1,
195 ConstBatchInputReference
const& batchX2,
196 RealMatrix
const& coefficients,
// Reuse the k(x_i, x_i) / k(y_j, y_j) values cached by the stateful eval.
201 InternalState
const& s = state.
toState<InternalState>();
202 std::size_t sizeX1 =
size(batchX1);
203 std::size_t sizeX2 =
size(batchX2);
// Divide the coefficients by sqrt(k(x_i, x_i) * k(y_j, y_j)), presumably
// element-wise.
205 RealMatrix weights = coefficients / sqrt(
outer_prod(s.kxx,s.kyy));
// Per-point weights for the self-evaluation terms: `weights` reduced with
// sum_columns and divided by 2*k(x_i, x_i), and reduced with sum_rows and
// divided by 2*k(y_j, y_j). Presumably these are the factors of the -1/2
// terms from differentiating the normalization -- TODO confirm.
210 RealVector wx =
sum_columns(weights) / (2.0 * s.kxx);
211 RealVector wy =
sum_rows(weights) / (2.0 * s.kyy);
217 RealVector subGradient(gradient.size());
219 RealMatrix coeff(1,1);
// Subtract the contribution of each self-evaluation k(x_i, x_i), evaluated on a
// one-element batch (subGradient is computed on lines not visible here).
220 for(std::size_t i = 0; i != sizeX1; ++i){
221 get(singleBatch,0) =
get(batchX1,i);
224 gradient -= subGradient;
// Same for each self-evaluation k(y_j, y_j) of the second batch.
226 for(std::size_t j = 0; j != sizeX2; ++j){
227 get(singleBatch,0) =
get(batchX2,j);
230 gradient -= subGradient;
// NOTE(review): the beginning of this member (its name and return type), some
// parameters and most of the body are not part of this excerpt.
// `gradient` is a BatchInputType, so this presumably computes a
// coefficient-weighted derivative with respect to the points of batchX1 --
// TODO confirm against the full source.
241 ConstBatchInputReference
const& batchX1,
242 ConstBatchInputReference
const& batchX2,
243 RealMatrix
const& coefficientsX2,
245 BatchInputType& gradient
// Reuse the k(x_i, x_i) / k(y_j, y_j) values cached by the stateful eval.
247 InternalState
const& s = state.
toState<InternalState>();
248 std::size_t sizeX1 =
size(batchX1);
// Divide the coefficients by sqrt(k(x_i, x_i) * k(y_j, y_j)), presumably
// element-wise.
250 RealMatrix weights = coefficientsX2 / sqrt(
outer_prod(s.kxx,s.kyy));
260 RealMatrix subGradient;
262 RealMatrix coeff(1,1);
// Per-point handling of each x_i as a one-element batch; the rest of the loop
// body is not visible here.
263 for(std::size_t i = 0; i != sizeX1; ++i){
264 get(singleBatch,0) =
get(batchX1,i);