GenericDistTrainer.h
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief Implementations of various distribution trainers.
5  *
6  *
7  *
8  * \author B. Li
9  * \date 2012
10  *
11  *
12  * \par Copyright 1995-2015 Shark Development Team
13  *
14  * <BR><HR>
15  * This file is part of Shark.
16  * <http://image.diku.dk/shark/>
17  *
18  * Shark is free software: you can redistribute it and/or modify
19  * it under the terms of the GNU Lesser General Public License as published
20  * by the Free Software Foundation, either version 3 of the License, or
21  * (at your option) any later version.
22  *
23  * Shark is distributed in the hope that it will be useful,
24  * but WITHOUT ANY WARRANTY; without even the implied warranty of
25  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26  * GNU Lesser General Public License for more details.
27  *
28  * You should have received a copy of the GNU Lesser General Public License
29  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
30  *
31  */
32 #ifndef SHARK_ALGORITHMS_TRAINERS_DISTRIBUTION_GENERIC_DIST_TRAINER_H
33 #define SHARK_ALGORITHMS_TRAINERS_DISTRIBUTION_GENERIC_DIST_TRAINER_H
34 
37 #include "shark/Rng/Normal.h"
38 #include "shark/Rng/Rng.h"
39 #include "shark/Rng/Uniform.h"
40 
41 namespace shark {
42 
/// Trainer that can train several kinds of distributions: it determines the
/// concrete type of the given distribution and forwards training to the
/// matching concrete trainer.
///
/// @note all train functions should be reentrant
47 :
49 {
50 public:
51 
52  /// Train an abstract distribution
53  /// @param abstractDist the distribution we want to train
54  /// @param input the input data used for training the dist
55  /// @throw throw shark exception if training attempt for this distribution failed
56  void train(AbstractDistribution& abstractDist, const std::vector<double>& input) const
57  {
58  // We have to do manual dispatching here unless distributions are trainer-aware/-friendly
59 
60  if (tryTrain<Normal<DefaultRngType> >(abstractDist, getNormalTrainer(), input))
61  return;
62  if (tryTrain<Normal<FastRngType> >(abstractDist, getNormalTrainer(), input))
63  return;
64 
65  // Other distributions go here
66 
67  throw SHARKEXCEPTION("No trainer for this distribution.");
68  }
69 
70 private:
71 
72  /// Try to train an abstract distribution with given concrete distribution type
73  /// @param abstractDist the abstract distribution
74  /// @param trainer the trainer to be used for training the distribution
75  /// @param input the input data
76  /// @tparam DistType the type of concrete distribution
77  /// @tparam TrainerType the type of trainer
78  /// @return true if the training attempt succeeded, false otherwise
79  template <typename DistType, typename TrainerType>
80  bool tryTrain(AbstractDistribution& abstractDist, const TrainerType& trainer, const std::vector<double>& input) const
81  {
82  DistType* dist = dynamic_cast<DistType*>(&abstractDist);
83  if (dist)
84  {
85  trainer.train(*dist, input);
86  return true;
87  }
88  else
89  {
90  return false;
91  }
92  }
93 };
94 
95 } // namespace shark {
96 
97 #endif // SHARK_ALGORITHMS_TRAINERS_DISTRIBUTION_GENERIC_DIST_TRAINER_H