MultiNomialDistribution.h
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief Implements a multinomial distribution
5  *
6  *
7  *
8  * \author O.Krause
9  * \date 2016
10  *
11  *
12  * \par Copyright 1995-2015 Shark Development Team
13  *
14  * <BR><HR>
15  * This file is part of Shark.
16  * <http://image.diku.dk/shark/>
17  *
18  * Shark is free software: you can redistribute it and/or modify
19  * it under the terms of the GNU Lesser General Public License as published
20  * by the Free Software Foundation, either version 3 of the License, or
21  * (at your option) any later version.
22  *
23  * Shark is distributed in the hope that it will be useful,
24  * but WITHOUT ANY WARRANTY; without even the implied warranty of
25  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26  * GNU Lesser General Public License for more details.
27  *
28  * You should have received a copy of the GNU Lesser General Public License
29  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
30  *
31  */
32 #ifndef SHARK_STATISTICS_MULTINOMIALDISTRIBUTION_H
33 #define SHARK_STATISTICS_MULTINOMIALDISTRIBUTION_H
34 
36 #include <shark/LinAlg/Cholesky.h>
37 #include <shark/Rng/GlobalRng.h>
38 
39 namespace shark {
40 
41 /// \brief Implements a multinomial distribution.
42 ///
/// A multinomial distribution (here with a single trial, also known as a categorical
/// distribution) is a discrete distribution with states 0,...,N-1
/// and probabilities p_i for state i with sum_i p_i = 1. This implementation uses
/// the alias method (Kronmal and Peterson, 1979) to draw each number in
/// constant time. Setup is O(N) and also quite fast, so this class is best
/// suited for drawing many numbers in succession from the same distribution.
48 ///
/// The idea of the alias method is to pair a state with low probability with a state with high
/// probability. A high probability state can in this case be included in several pairs. To draw,
/// first one of the N pairs is selected, and afterwards a biased coin toss decides
/// which element of the pair is taken.
public:
	/// \brief Type of a drawn sample: the index of a state in 0,...,N-1.
	typedef std::size_t result_type;
56 
58 
59  /// \brief Constructor
60  /// \param [in] probabilities Probability vector
62  : m_probabilities(probabilities){
63  update();
64  }
65 
66  /// \brief Stores/Restores the distribution from the supplied archive.
67  /// \param [in,out] ar The archive to read from/write to.
68  /// \param [in] version Currently unused.
69  template<typename Archive>
70  void serialize( Archive & ar, const unsigned int version ) {
71  ar & BOOST_SERIALIZATION_NVP( m_probabilities );
72  ar & BOOST_SERIALIZATION_NVP( m_q );
73  ar & BOOST_SERIALIZATION_NVP( m_J );
74  }
75 
76  /// \brief Accesses the probabilityvector defining the distribution.
77  RealVector const& probabilities() const {
78  return m_probabilities;
79  }
80 
81  /// \brief Accesses a mutable reference to the probability vector
82  /// defining the distribution. Allows for l-value semantics.
83  ///
84  /// ATTENTION: If the reference is altered, update needs to be called manually.
85  RealVector& probabilities() {
86  return m_probabilities;
87  }
88 
89  /// \brief Samples the distribution.
90  template<class RngType>
91  result_type operator()(RngType& rng) const {
92  std::size_t numStates = m_probabilities.size();
93 
94  std::size_t index = discrete(rng,0,numStates-1);
95 
96  if(coinToss(rng, m_q[index]))
97  return index;
98  else
99  return m_J[index];
100  }
101 
102 
103  void update() {
104  std::size_t numStates = m_probabilities.size();
105  m_q.resize(numStates);
106  m_J.resize(numStates);
107  m_probabilities/=sum(m_probabilities);
108 
109  // Sort the data into the outcomes with probabilities
110  // that are larger and smaller than 1/K.
111  std::deque<std::size_t> smaller;
112  std::deque<std::size_t> larger;
113  for(std::size_t i = 0;i != numStates; ++i){
114  m_q(i) = numStates*m_probabilities(i);
115  if(m_q(i) < 1.0)
116  smaller.push_back(i);
117  else
118  larger.push_back(i);
119  }
120  // Loop though and create little binary mixtures that
121  // appropriately allocate the larger outcomes over the
122  // overall uniform mixture.
123  while(!smaller.empty() && !larger.empty()){
124  std::size_t smallIndex = smaller.front();
125  std::size_t largeIndex = larger.front();
126  smaller.pop_front();
127  larger.pop_front();
128 
129  m_J[smallIndex] = largeIndex;
130  m_q[largeIndex] -= 1.0 - m_q[smallIndex];
131 
132  if(m_q[largeIndex] < 1.0)
133  smaller.push_back(largeIndex);
134  else
135  larger.push_back(largeIndex);
136  }
137  for(std::size_t i = 0; i != larger.size(); ++i){
138  m_q[larger[i]]=std::min(m_q[larger[i]],1.0);
139  }
140  }
141 
private:
	RealVector m_probabilities; ///< probability of every state (normalized by update()).
	RealVector m_q; ///< m_q[i]: probability that slot i returns its own state i rather than its alias m_J[i].
	blas::vector<std::size_t> m_J; ///< m_J[i]: the alias of slot i, i.e. the second element of the pair (i, m_J[i]).
146 };
147 }
148 
149 #endif