35 #ifndef SHARK_ALGORITHMS_TRAINERS_NB_CLASSIFIER_TRAINER_H 36 #define SHARK_ALGORITHMS_TRAINERS_NB_CLASSIFIER_TRAINER_H 43 #include <boost/foreach.hpp> 44 #include "boost/tuple/tuple.hpp" 57 template <
class InputType = RealVector,
class OutputType =
unsigned int>
64 typedef typename InputType::value_type InputValueType;
70 {
// Returns the fixed string literal "NBClassifierTrainer".
return "NBClassifierTrainer"; }
// Ask the model for its two dimensions in a single call; boost::tie unpacks
// the returned pair/tuple into classSize and featureSize.
78 std::size_t classSize;
79 std::size_t featureSize;
80 boost::tie(classSize, featureSize) = model.
getDistSize();
// Scratch buffer that receives the sample values for one (class, feature)
// combination at a time before they are handed to the distribution trainer.
85 std::vector<InputValueType> buffer;
// Outer loop walks class indices (i), inner loop walks feature indices (j).
89 for (std::size_t i = 0; i < classSize; ++i)
91 for (std::size_t j = 0; j < featureSize; ++j)
// Fill `buffer` with feature j of all elements labelled i, then train `dist` on it.
// NOTE(review): getFeatureSample() checks samples.empty() on entry, so the buffer
// presumably has to be emptied before each call; neither that step nor the
// definition of `dist` appears among these statements -- confirm.
// NOTE(review): `i` is a std::size_t but the callee takes OutputType classIndex,
// so an implicit conversion happens at this call -- confirm it is lossless.
95 getFeatureSample(buffer, dataset, i, j);
96 m_distTrainer.
train(dist, buffer);
// Presumably the number of dataset elements per class (inferred from the call
// to classSizes(); verify against its definition).
101 const std::vector<std::size_t> occuranceCounter =
classSizes(dataset);
// Each class prior is set to (count for class i) / totalClassOccurances.
// NOTE(review): occuranceCounter[i] is a std::size_t. If totalClassOccurances is
// also an integer type, this is integer division and the prior truncates
// (to 0 whenever the count is smaller than the total). Its declaration is not
// among these statements -- confirm it is a floating-point value.
104 for (std::size_t i = 0; i < classSize; ++i) {
105 model.
setClassPrior(i, occuranceCounter[i] / totalClassOccurances);
/// Collects, for one class and one feature, the feature value of every
/// matching dataset element.
///
/// @param samples       output buffer; matching values are appended via push_back.
/// @param classIndex    label value an element must equal to be included.
/// @param featureIndex  index passed to elem.input(...) to select the feature.
///
/// NOTE(review): the body also reads a `dataset` argument whose declaration is
/// not shown between `samples` and `classIndex` -- confirm its type.
121 void getFeatureSample(
122 std::vector<InputValueType>& samples,
124 OutputType classIndex,
125 std::size_t featureIndex
// Precondition check: the caller is expected to pass in an empty buffer.
// (How SHARK_CHECK reacts to a failed condition is defined elsewhere.)
127 SHARK_CHECK(samples.empty(),
"The output buffer should be cleaned before usage usually.");
// Visit every element of the dataset; keep the requested feature value of
// each element whose label equals classIndex.
129 BOOST_FOREACH(reference elem, dataset.
elements()){
130 if (elem.label == classIndex)
131 samples.push_back(elem.input(featureIndex));
141 #endif // SHARK_ALGORITHMS_TRAINERS_NB_CLASSIFIER_TRAINER_H