Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
include
shark
Algorithms
Trainers
Distribution
GenericDistTrainer.h
Go to the documentation of this file.
1
/*!
2
*
3
*
4
* \brief Implementations of various distribution trainers.
5
*
6
*
7
*
8
* \author B. Li
9
* \date 2012
10
*
11
*
12
* \par Copyright 1995-2015 Shark Development Team
13
*
14
* <BR><HR>
15
* This file is part of Shark.
16
* <http://image.diku.dk/shark/>
17
*
18
* Shark is free software: you can redistribute it and/or modify
19
* it under the terms of the GNU Lesser General Public License as published
20
* by the Free Software Foundation, either version 3 of the License, or
21
* (at your option) any later version.
22
*
23
* Shark is distributed in the hope that it will be useful,
24
* but WITHOUT ANY WARRANTY; without even the implied warranty of
25
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26
* GNU Lesser General Public License for more details.
27
*
28
* You should have received a copy of the GNU Lesser General Public License
29
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
30
*
31
*/
32
#ifndef SHARK_ALGORITHMS_TRAINERS_DISTRIBUTION_GENERIC_DIST_TRAINER_H
33
#define SHARK_ALGORITHMS_TRAINERS_DISTRIBUTION_GENERIC_DIST_TRAINER_H
34
35
#include "
shark/Algorithms/Trainers/Distribution/DistTrainerContainer.h
"
36
#include "
shark/Algorithms/Trainers/Distribution/NormalTrainer.h
"
37
#include "
shark/Rng/Normal.h
"
38
#include "
shark/Rng/Rng.h
"
39
#include "
shark/Rng/Uniform.h
"
40
41
namespace
shark
{
42
43
/// The trainer which is smart enough to train different kinds of distributions
44
///
45
/// @note all train functions should be reentrant
46
class
GenericDistTrainer
47
:
48
public
DistTrainerContainer
49
{
50
public
:
51
52
/// Train an abstract distribution
53
/// @param abstractDist the distribution we want to train
54
/// @param input the input data used for training the dist
55
/// @throw throw shark exception if training attempt for this distribution failed
56
void
train
(
AbstractDistribution
& abstractDist,
const
std::vector<double>& input)
const
57
{
58
// We have to do manual dispatching here unless distributions are trainer-aware/-friendly
59
60
if
(tryTrain<
Normal<DefaultRngType>
>(abstractDist,
getNormalTrainer
(), input))
61
return
;
62
if
(tryTrain<
Normal<FastRngType>
>(abstractDist,
getNormalTrainer
(), input))
63
return
;
64
65
// Other distributions go here
66
67
throw
SHARKEXCEPTION
(
"No trainer for this distribution."
);
68
}
69
70
private
:
71
72
/// Try to train an abstract distribution with given concrete distribution type
73
/// @param abstractDist the abstract distribution
74
/// @param trainer the trainer to be used for training the distribution
75
/// @param input the input data
76
/// @tparam DistType the type of concrete distribution
77
/// @tparam TrainerType the type of trainer
78
/// @return true if the training attempt succeeded, false otherwise
79
template
<
typename
DistType,
typename
TrainerType>
80
bool
tryTrain(
AbstractDistribution
& abstractDist,
const
TrainerType& trainer,
const
std::vector<double>& input)
const
81
{
82
DistType* dist =
dynamic_cast<
DistType*
>
(&abstractDist);
83
if
(dist)
84
{
85
trainer.train(*dist, input);
86
return
true
;
87
}
88
else
89
{
90
return
false
;
91
}
92
}
93
};
94
95
}
// namespace shark {
96
97
#endif // SHARK_ALGORITHMS_TRAINERS_DISTRIBUTION_GENERIC_DIST_TRAINER_H