Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
include
shark
Models
MeanModel.h
Go to the documentation of this file.
//===========================================================================
/*!
 *
 *
 * \brief       Implements the Mean Model that can be used for ensemble classifiers
 *
 *
 *
 * \author      Kang Li, O. Krause
 * \date        2014
 *
 *
 * \par Copyright 1995-2015 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://image.diku.dk/shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#ifndef SHARK_MODELS_MEANMODEL_H
36
#define SHARK_MODELS_MEANMODEL_H
37
38
namespace
shark
{
39
///
40
/// \brief Calculates the weighted mean of a set of models
41
///
42
template
<
class
ModelType>
43
class
MeanModel
:
public
AbstractModel
<typename ModelType::InputType, typename ModelType::OutputType>
44
{
45
private
:
46
typedef
AbstractModel<typename ModelType::InputType, typename ModelType::OutputType>
base_type
;
47
public
:
48
49
/// Constructor
50
MeanModel
():
m_weightSum
(0){}
51
52
std::string
name
()
const
53
{
return
"MeanModel"
; }
54
55
using
base_type::eval
;
56
void
eval
(
typename
base_type::BatchInputType
const
& patterns,
typename
base_type::BatchOutputType
& outputs)
const
{
57
m_models
[0].eval(patterns,outputs);
58
outputs *=
m_weight
[0];
59
for
(std::size_t i = 1; i !=
m_models
.size(); i++)
60
noalias
(outputs) +=
m_weight
[i] *
m_models
[i](patterns);
61
outputs /=
m_weightSum
;
62
}
63
64
void
eval
(
typename
base_type::BatchInputType
const
& patterns,
typename
base_type::BatchOutputType
& outputs,
State
& state)
const
{
65
eval
(patterns,outputs);
66
}
67
68
69
/// This model does not have any parameters.
70
RealVector
parameterVector
()
const
{
71
return
RealVector();
72
}
73
74
/// This model does not have any parameters
75
void
setParameterVector
(
const
RealVector& param) {
76
SHARK_ASSERT
(param.size() == 0);
77
}
78
void
read
(
InArchive
& archive){
79
archive >>
m_models
;
80
archive >>
m_weight
;
81
archive >>
m_weightSum
;
82
}
83
void
write
(
OutArchive
& archive)
const
{
84
archive <<
m_models
;
85
archive <<
m_weight
;
86
archive <<
m_weightSum
;
87
}
88
89
/// \brief Removes all models from the ensemble
90
void
clearModels
(){
91
m_models
.clear();
92
m_weight
.clear();
93
m_weightSum
= 0.0;
94
}
95
96
/// \brief Adds a new model to the ensemble.
97
///
98
/// \param model the new model
99
/// \param weight weight of the model. must be > 0
100
void
addModel
(ModelType
const
& model,
double
weight
= 1.0){
101
SHARK_CHECK
(
weight
> 0,
"Weights must be positive"
);
102
m_models
.push_back(model);
103
m_weight
.push_back(
weight
);
104
m_weightSum
+=
weight
;
105
}
106
107
/// \brief Returns the weight of the i-th model
108
double
const
&
weight
(std::size_t i)
const
{
109
return
m_weight
[i];
110
}
111
112
/// \brief sets the weight of the i-th model
113
void
setWeight
(std::size_t i,
double
newWeight){
114
m_weightSum
=newWeight -
m_weight
[i];
115
m_weight[i] = newWeight;
116
}
117
118
/// \brief Returns the number of models.
119
std::size_t
numberOfModels
()
const
{
120
return
m_models
.size();
121
}
122
123
protected
:
124
/// collection of models.
125
std::vector<ModelType>
m_models
;
126
127
/// Weight of the mean.
128
std::vector<double>
m_weight
;
129
130
/// Total sum of weights.
131
double
m_weightSum
;
132
};
133
134
135
}
136
#endif // SHARK_MODELS_MEANMODEL_H