NBClassifierTrainer.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Trainer of Naive Bayes classifier
6  *
7  *
8  *
9  *
10  * \author B. Li
11  * \date 2012
12  *
13  *
14  * \par Copyright 1995-2015 Shark Development Team
15  *
16  * <BR><HR>
17  * This file is part of Shark.
18  * <http://image.diku.dk/shark/>
19  *
20  * Shark is free software: you can redistribute it and/or modify
21  * it under the terms of the GNU Lesser General Public License as published
22  * by the Free Software Foundation, either version 3 of the License, or
23  * (at your option) any later version.
24  *
25  * Shark is distributed in the hope that it will be useful,
26  * but WITHOUT ANY WARRANTY; without even the implied warranty of
27  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
28  * GNU Lesser General Public License for more details.
29  *
30  * You should have received a copy of the GNU Lesser General Public License
31  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
32  *
33  */
34 //===========================================================================
35 #ifndef SHARK_ALGORITHMS_TRAINERS_NB_CLASSIFIER_TRAINER_H
36 #define SHARK_ALGORITHMS_TRAINERS_NB_CLASSIFIER_TRAINER_H
37 
40 #include "shark/Core/Exception.h"
42 
43 #include <boost/foreach.hpp>
44 #include "boost/tuple/tuple.hpp"
45 #include <cmath>
46 
47 namespace shark {
48 
49 /// @brief Trainer for naive Bayes classifier
50 ///
/// Training a naive Bayes classifier means estimating two things:
/// (1) the prior probability of each class, and
/// (2) the parameters of the distribution of each feature conditioned on each class.
54 ///
55 /// @tparam InputType the type of feature vector
56 /// @tparam OutputType the type of class
57 template <class InputType = RealVector, class OutputType = unsigned int>
59 :public AbstractTrainer<NBClassifier<InputType, OutputType> >
60 {
61 private:
62 
64  typedef typename InputType::value_type InputValueType;
65 
66 public:
67 
68  /// \brief From INameable: return the class name.
69  std::string name() const
70  { return "NBClassifierTrainer"; }
71 
72  /// @see AbstractTrainer::train
73  void train(NBClassifierType& model, LabeledData<InputType, OutputType> const& dataset)
74  {
75  SIZE_CHECK(dataset.numberOfElements() > 0u);
76 
77  // Get size of class/feature
78  std::size_t classSize;
79  std::size_t featureSize;
80  boost::tie(classSize, featureSize) = model.getDistSize();
81  SHARK_CHECK(classSize == numberOfClasses(dataset), "Number of classes in dataset and model should match.");
82  SHARK_CHECK(featureSize == inputDimension(dataset), "Number of features in dataset and model should match.");
83 
84  // Initialize trainer & buffer
85  std::vector<InputValueType> buffer;
86  buffer.reserve(dataset.numberOfElements() / classSize);
87 
88  // Train individual feature distribution
89  for (std::size_t i = 0; i < classSize; ++i)
90  {
91  for (std::size_t j = 0; j < featureSize; ++j)
92  {
93  AbstractDistribution& dist = model.getFeatureDist(i, j);
94  buffer.clear();
95  getFeatureSample(buffer, dataset, i, j);
96  m_distTrainer.train(dist, buffer);
97  }
98  }
99 
100  // Figure out class distribution and add it to the model
101  const std::vector<std::size_t> occuranceCounter = classSizes(dataset);
102 
103  const double totalClassOccurances = dataset.numberOfElements();
104  for (std::size_t i = 0; i < classSize; ++i) {
105  model.setClassPrior(i, occuranceCounter[i] / totalClassOccurances);
106  }
107  }
108 
109  /// Return the distribution trainer container which allows user to check or set individual distribution trainer
110  DistTrainerContainer& getDistTrainerContainer() { return m_distTrainer; }
111 
112 private:
113 
114  /// Get samples for a given feature in a given class
115  /// @param samples[out] the container which will store the samples we want to get
116  /// @param dataset the entire dataset
117  /// @param classIndex the index of class we are interested in
118  /// @param featureIndex the index of feature we are interested in
119  ///
120  /// @note This can/should be optimized
121  void getFeatureSample(
122  std::vector<InputValueType>& samples,
123  const LabeledData<InputType, OutputType>& dataset,
124  OutputType classIndex,
125  std::size_t featureIndex
126  ) const{
127  SHARK_CHECK(samples.empty(), "The output buffer should be cleaned before usage usually.");
129  BOOST_FOREACH(reference elem, dataset.elements()){
130  if (elem.label == classIndex)
131  samples.push_back(elem.input(featureIndex));
132  }
133  }
134 
135  /// Generic distribution trainer
136  GenericDistTrainer m_distTrainer;
137 };
138 
139 } // namespace shark {
140 
141 #endif // SHARK_ALGORITHMS_TRAINERS_NB_CLASSIFIER_TRAINER_H