35 #ifndef SHARK_MODEL_CONCATENATEDMODEL_H 36 #define SHARK_MODEL_CONCATENATEDMODEL_H 39 #include <boost/scoped_ptr.hpp> 40 #include <boost/serialization/scoped_ptr.hpp> 51 template<
class InputType,
class OutputType>
// ConcatenatedModelWrapperBase: abstract AbstractModel<InputType,OutputType>
// base for wrappers that chain two models. It adds a polymorphic clone() and
// two flags selecting whether the parameters of the first and/or second
// chained model are exposed for optimization.
52 class ConcatenatedModelWrapperBase:
public AbstractModel<InputType,OutputType>{
// Both models are optimized by default.
54 ConcatenatedModelWrapperBase():m_optimizeFirst(true), m_optimizeSecond(true){}
// Polymorphic copy, implemented by the concrete wrapper classes.
55 virtual ConcatenatedModelWrapperBase<InputType,OutputType>* clone()
const = 0;
// Whether the first model's parameters are optimized (read access).
57 bool optimizeFirstModelParameters()
const{
58 return m_optimizeFirst;
// Whether the first model's parameters are optimized (assignable reference).
61 bool& optimizeFirstModelParameters(){
62 return m_optimizeFirst;
// Whether the second model's parameters are optimized (read access).
65 bool optimizeSecondModelParameters()
const{
66 return m_optimizeSecond;
// Whether the second model's parameters are optimized (assignable reference).
69 bool& optimizeSecondModelParameters(){
70 return m_optimizeSecond;
// NOTE(review): the declaration of m_optimizeFirst, the access specifiers and
// the closing braces of this class are not visible in this view.
74 bool m_optimizeSecond;
80 template<
class InputType,
class IntermediateType,
class OutputType>
// ConcatenatedModelWrapper: chains firstModel (InputType -> IntermediateType)
// with secondModel (IntermediateType -> OutputType). It keeps raw pointers to
// both models; no deletion of them is visible in this class.
81 class ConcatenatedModelWrapper :
public ConcatenatedModelWrapperBase<InputType, OutputType> {
// Pointers to the two chained models.
84 AbstractModel<InputType,IntermediateType>* m_firstModel;
85 AbstractModel<IntermediateType,OutputType>* m_secondModel;
87 typedef ConcatenatedModelWrapperBase<InputType, OutputType> base_type;
// The optimize flags live in the (dependent) base class.
88 using base_type::m_optimizeFirst;
89 using base_type::m_optimizeSecond;
// Evaluation state: output of the first model (= input of the second model)
// plus one sub-state per chained model.
// NOTE(review): the typedef of BatchIntermediateType is not visible in this view.
91 struct InternalState:
public State{
92 BatchIntermediateType intermediateResult;
93 boost::shared_ptr<State> firstModelState;
94 boost::shared_ptr<State> secondModelState;
97 typedef typename base_type::BatchInputType BatchInputType;
99 typedef typename base_type::BatchOutputType BatchOutputType;
100 ConcatenatedModelWrapper(
101 AbstractModel<InputType, IntermediateType>* firstModel,
102 AbstractModel<IntermediateType, OutputType>* secondModel)
103 : m_firstModel(firstModel), m_secondModel(secondModel)
105 if (firstModel->hasFirstParameterDerivative()
106 && secondModel->hasFirstParameterDerivative()
107 && secondModel ->hasFirstInputDerivative())
109 this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
112 if (firstModel->hasFirstInputDerivative()
113 && secondModel->hasFirstInputDerivative())
115 this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
120 std::string name()
const 121 {
return "Concatenation<" + m_firstModel->name() +
"," + m_secondModel->name() +
">"; }
123 ConcatenatedModelWrapperBase<InputType, OutputType>* clone()
const{
124 return new ConcatenatedModelWrapper<InputType, IntermediateType, OutputType>(*this);
// Current parameters of the models selected for optimization: only the
// first model's, only the second model's, or (both flags set) presumably the
// first model's followed by the second model's. The buffer is sized by
// numberOfParameters().
// NOTE(review): the statement of the "both models" branch and the final
// return are not visible in this view — confirm against the full file.
127 RealVector parameterVector()
const {
128 RealVector params(numberOfParameters());
129 if(m_optimizeFirst && m_optimizeSecond)
131 else if (m_optimizeFirst)
132 params = m_firstModel->parameterVector();
133 else if (m_optimizeSecond)
134 params = m_secondModel->parameterVector();
// Sets the parameters of the models selected for optimization; the single
// optimized model receives newParameters unchanged.
// NOTE(review): the statement of the "both models" branch (presumably
// splitting newParameters between the two models in parameterVector() order)
// is not visible in this view.
138 void setParameterVector(RealVector
const& newParameters) {
139 if(m_optimizeFirst && m_optimizeSecond)
141 else if (m_optimizeFirst)
142 m_firstModel->setParameterVector(newParameters);
143 else if (m_optimizeSecond)
144 m_secondModel->setParameterVector(newParameters);
// Creates an InternalState holding fresh states of both chained models.
// The raw allocation is handed to a shared_ptr right away, before the
// sub-states are created.
// NOTE(review): the return of ptrState is not visible in this view.
148 boost::shared_ptr<State> createState()
const{
149 InternalState* state =
new InternalState();
150 boost::shared_ptr<State> ptrState(state);
151 state->firstModelState = m_firstModel->createState();
152 state->secondModelState = m_secondModel->createState();
// Number of parameters exposed for optimization, accumulated per model; the
// second model only counts when m_optimizeSecond is set.
// NOTE(review): the guard before the first addition (presumably
// `if (m_optimizeFirst)`) and the return are not visible in this view; as
// shown, the first model's count would be added unconditionally — confirm.
156 std::size_t numberOfParameters()
const {
157 std::size_t numParams = 0;
159 numParams += m_firstModel->numberOfParameters();
160 if (m_optimizeSecond)
161 numParams += m_secondModel->numberOfParameters();
// Stateless evaluation of the chain.
// NOTE(review): only the first model's application to `patterns` is visible;
// the enclosing call that presumably feeds its result into the second model
// and writes `outputs` is missing from this view.
166 void eval( BatchInputType
const& patterns, BatchOutputType& outputs)
const{
168 (*m_firstModel)(patterns),
173 void eval( BatchInputType
const& patterns, BatchOutputType& outputs, State& state)
const{
174 InternalState& s = state.toState<InternalState>();
175 m_firstModel->eval(patterns, s.intermediateResult,*s.firstModelState);
176 m_secondModel->eval(s.intermediateResult, outputs,*s.secondModelState);
// Weighted derivative w.r.t. the optimized parameters (chain rule).
// The second model's weighted input derivative at the stored intermediate
// result serves as the coefficient batch for the first model.
// Branches:
//  - both optimized and the first model has parameters: gradient is the
//    concatenation [first model derivative, second model derivative];
//  - only the first optimized (and it has parameters): first model derivative;
//  - otherwise, if the second is optimized: second model derivative.
// Reads s.intermediateResult, which eval(patterns, outputs, state) fills.
// NOTE(review): the `) const {` closing the parameter list and the code after
// the last branch's arguments are not visible in this view.
179 void weightedParameterDerivative(
180 BatchInputType
const& patterns, BatchOutputType
const& coefficients, State
const& state, RealVector& gradient
182 InternalState
const& s = state.toState<InternalState>();
185 std::size_t numParamsFirst = m_firstModel->numberOfParameters();
186 if(m_optimizeFirst && m_optimizeSecond && numParamsFirst != 0){
187 RealVector firstParameterDerivative;
188 BatchIntermediateType secondInputDerivative;
189 RealVector secondParameterDerivative;
// One call yields both the second model's parameter derivative and its input
// derivative (the back-propagated coefficients for the first model).
191 m_secondModel->weightedDerivatives(
192 s.intermediateResult,coefficients,*s.secondModelState,
193 secondParameterDerivative,secondInputDerivative
195 m_firstModel->weightedParameterDerivative(patterns,secondInputDerivative,*s.firstModelState,firstParameterDerivative);
197 gradient.resize(numberOfParameters());
198 init(gradient)<<firstParameterDerivative,secondParameterDerivative;
199 }
else if(m_optimizeFirst && numParamsFirst != 0){
// NOTE(review): firstParameterDerivative is unused in this branch; the first
// model writes straight into gradient.
200 RealVector firstParameterDerivative;
201 BatchIntermediateType secondInputDerivative;
203 m_secondModel->weightedInputDerivative(
204 s.intermediateResult,coefficients,*s.secondModelState,secondInputDerivative
206 m_firstModel->weightedParameterDerivative(patterns,secondInputDerivative,*s.firstModelState,gradient);
207 }
else if(m_optimizeSecond){
208 m_secondModel->weightedParameterDerivative(
209 s.intermediateResult,coefficients,*s.secondModelState,
217 void weightedInputDerivative(
218 BatchInputType
const& patterns, BatchOutputType
const& coefficients, State
const& state, BatchOutputType& gradient
220 InternalState
const& s = state.toState<InternalState>();
221 BatchIntermediateType secondInputDerivative;
222 m_secondModel->weightedInputDerivative(s.intermediateResult, coefficients, *s.secondModelState, secondInputDerivative);
223 m_firstModel->weightedInputDerivative(patterns, secondInputDerivative, *s.firstModelState, gradient);
227 virtual void weightedDerivatives(
228 BatchInputType
const & patterns,
229 BatchOutputType
const & coefficients,
231 RealVector& parameterDerivative,
232 BatchInputType& inputDerivative
234 InternalState
const& s = state.toState<InternalState>();
235 std::size_t firstParam=m_firstModel->numberOfParameters();
236 std::size_t secondParam=m_secondModel->numberOfParameters();
237 parameterDerivative.resize(firstParam+secondParam);
239 RealVector firstParameterDerivative;
240 BatchIntermediateType secondInputDerivative;
241 RealVector secondParameterDerivative;
242 if(m_optimizeSecond){
243 m_secondModel->weightedDerivatives(
244 s.intermediateResult, coefficients, *s.firstModelState, secondParameterDerivative, secondInputDerivative
247 m_secondModel->weightedInputDerivative(
248 s.intermediateResult, coefficients, *s.firstModelState, secondInputDerivative
252 m_firstModel->weightedDerivatives(
253 patterns, secondInputDerivative, *s.secondModelState, parameterDerivative, inputDerivative
256 m_firstModel->weightedInputDerivative(
257 patterns, secondInputDerivative, *s.secondModelState, inputDerivative
261 parameterDerivative.resize(firstParam+secondParam);
262 init(parameterDerivative)<<firstParameterDerivative,secondParameterDerivative;
// Deserialization fragment (enclosing function header not visible): reads the
// two models, then the two optimize flags.
266 m_firstModel->read(archive);
267 m_secondModel->read(archive);
268 archive >> m_optimizeFirst;
269 archive >> m_optimizeSecond;
// Serialization fragment: writes first model, second model, then the two
// optimize flags — the same order the read fragment consumes them.
274 m_firstModel->write(archive);
275 m_secondModel->write(archive);
276 archive << m_optimizeFirst;
277 archive << m_optimizeSecond;
286 template<
class InputType,
class IntermediateType,
class OutputType>
// ConcatenatedModelList: a ConcatenatedModelWrapper whose first stage is itself
// a concatenation (ConcatenatedModelWrapperBase<InputType,IntermediateType>),
// which allows chains of more than two models. It stores a clone of that first
// stage and deletes it in its destructor.
287 class ConcatenatedModelList:
public ConcatenatedModelWrapper<InputType,IntermediateType,OutputType>{
289 typedef ConcatenatedModelWrapper<InputType,IntermediateType,OutputType> base_type;
// Type of the (already concatenated) first stage.
290 typedef ConcatenatedModelWrapperBase<InputType,IntermediateType> FirstModelType;
293 ConcatenatedModelList(
294 const FirstModelType& firstModel,
295 AbstractModel<IntermediateType, OutputType>* secondModel
296 ):base_type(firstModel.clone(),secondModel){
297 if (base_type::m_firstModel->hasFirstParameterDerivative()
298 && secondModel->hasFirstParameterDerivative()
299 && secondModel ->hasFirstInputDerivative())
301 this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
304 if (base_type::m_firstModel->hasFirstInputDerivative()
305 && secondModel->hasFirstInputDerivative())
307 this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
311 ~ConcatenatedModelList(){
312 delete base_type::m_firstModel;
316 std::string name()
const 317 {
return "Concatenation<" + base_type::m_firstModel->name() +
"," + base_type::m_secondModel->name() +
">"; }
319 ConcatenatedModelWrapperBase<InputType, OutputType>* clone()
const{
320 return new ConcatenatedModelList<InputType, IntermediateType, OutputType>(
321 *
static_cast<FirstModelType*
>(base_type::m_firstModel),
322 base_type::m_secondModel
332 template<
class InputT,
class IntermediateT,
class OutputT>
// Builder joining two plain models (the line with its name and parameters is
// not visible in this view; presumably an operator>> overload). The returned
// ConcatenatedModelWrapper stores the addresses of both arguments — no copies
// are made — so both models must outlive the result and its copies.
333 detail::ConcatenatedModelWrapper<InputT,IntermediateT,OutputT>
335 return detail::ConcatenatedModelWrapper<InputT,IntermediateT,OutputT> (&firstModel,&secondModel);
339 template<
class InputT,
class IntermediateT,
class OutputT>
// Builder extending an existing concatenation by one more model (name line not
// visible here; presumably an operator>> overload). The returned
// ConcatenatedModelList stores a clone of `firstModel` and the address of
// `secondModel`, which must therefore outlive the result.
// NOTE(review): the result is returned by value; this is only safe if
// ConcatenatedModelList copies deep-copy the owned first stage.
340 detail::ConcatenatedModelList<InputT,IntermediateT,OutputT>
342 const detail::ConcatenatedModelWrapperBase<InputT,IntermediateT>& firstModel,
345 return detail::ConcatenatedModelList<InputT,IntermediateT,OutputT> (firstModel,&secondModel);
// ConcatenatedModel (class head not visible in this view): user-facing model
// that holds a polymorphic concatenation wrapper and forwards calls to it.
361 template<
class InputType,
class OutputType>
// Owned wrapper; the fragments below forward to it.
364 boost::scoped_ptr<detail::ConcatenatedModelWrapperBase<InputType, OutputType> > m_wrapper;
// Constructor fragment (header not visible): wraps two models in a new
// ConcatenatedModelWrapper and sets the derivative feature bits it reports.
375 m_wrapper.reset(
new detail::ConcatenatedModelWrapper<InputType, T, OutputType>(firstModel, secondModel));
376 if (m_wrapper->hasFirstParameterDerivative()){
377 this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
380 if (m_wrapper->hasFirstInputDerivative()){
381 this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
// Fragment (header not visible): stores a clone of an existing wrapper and
// sets the derivative feature bits it reports.
386 m_wrapper.reset(wrapper.clone());
387 if (m_wrapper->hasFirstParameterDerivative()){
388 this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
391 if (m_wrapper->hasFirstInputDerivative()){
392 this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
// Second clone-of-wrapper fragment (header not visible; possibly an
// assignment operator).
// NOTE(review): if this is an assignment, m_features is only OR-ed here, so
// feature bits from a previously held wrapper are never cleared — confirm.
397 m_wrapper.reset(wrapper.clone());
398 if (m_wrapper->hasFirstParameterDerivative()){
399 this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
402 if (m_wrapper->hasFirstInputDerivative()){
403 this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
// Accessor fragments: forward to the wrapper's optimize flags.
413 return m_wrapper->optimizeFirstModelParameters();
420 return m_wrapper->optimizeFirstModelParameters();
427 return m_wrapper->optimizeSecondModelParameters();
434 return m_wrapper->optimizeSecondModelParameters();
// Copy constructor fragment: the wrapper is deep-copied via clone().
438 :m_wrapper(src.m_wrapper->clone()) {
// name() fragment: forwards to the wrapper.
444 {
return m_wrapper->name(); }
// Assignment fragment: swaps the owned wrapper with that of a local `copy`
// (presumably copy-and-swap).
448 swap(m_wrapper,copy.m_wrapper);
// Forwarding fragments: parameters, parameter count and state creation are
// delegated to the wrapper.
454 return m_wrapper->parameterVector();
458 m_wrapper->setParameterVector(newParameters);
462 return m_wrapper->numberOfParameters();
466 return m_wrapper->createState();
469 using base_type::eval;
470 void eval(BatchInputType
const& patterns, BatchOutputType& outputs)
const {
471 m_wrapper->eval(patterns, outputs);
473 void eval(BatchInputType
const& patterns, BatchOutputType& outputs,
State& state)
const {
474 m_wrapper->eval(patterns, outputs, state);
// Parameter-derivative forwarding fragment (function header not visible).
478 BatchInputType
const& patterns, BatchOutputType
const& coefficients,
State const& state, RealVector& gradient
480 m_wrapper->weightedParameterDerivative(patterns, coefficients, state, gradient);
// Input-derivative forwarding fragment (function header not visible).
// NOTE(review): `derivatives` is declared as BatchOutputType&, but an input
// derivative is a batch of inputs (the combined version below uses
// BatchInputType&); this presumably only compiles when both batch types
// coincide — consider BatchInputType&.
484 BatchInputType
const& patterns, BatchOutputType
const& coefficients,
State const& state, BatchOutputType& derivatives
486 m_wrapper->weightedInputDerivative(patterns, coefficients, state, derivatives);
// Combined parameter/input derivative forwarding fragment (header not visible).
490 BatchInputType
const & patterns,
491 BatchOutputType
const & coefficients,
493 RealVector& parameterDerivative,
494 BatchInputType& inputDerivative
496 m_wrapper->weightedDerivatives(patterns, coefficients, state, parameterDerivative,inputDerivative);
// Serialization fragments: reading and writing are delegated to the wrapper.
501 m_wrapper->read(archive);
506 m_wrapper->write(archive);