MarkovChain.h
Go to the documentation of this file.
1 /*!
2  *
3  *
4  * \brief -
5  *
6  * \author -
7  * \date -
8  *
9  *
10  * \par Copyright 1995-2015 Shark Development Team
11  *
12  * <BR><HR>
13  * This file is part of Shark.
14  * <http://image.diku.dk/shark/>
15  *
16  * Shark is free software: you can redistribute it and/or modify
17  * it under the terms of the GNU Lesser General Public License as published
18  * by the Free Software Foundation, either version 3 of the License, or
19  * (at your option) any later version.
20  *
21  * Shark is distributed in the hope that it will be useful,
22  * but WITHOUT ANY WARRANTY; without even the implied warranty of
23  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24  * GNU Lesser General Public License for more details.
25  *
26  * You should have received a copy of the GNU Lesser General Public License
27  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
28  *
29  */
30 #ifndef SHARK_UNSUPERVISED_RBM_SAMPLING_MARKOVCHAIN_H
31 #define SHARK_UNSUPERVISED_RBM_SAMPLING_MARKOVCHAIN_H
32 
33 #include <shark/Data/Dataset.h>
36 #include "Impl/SampleTypes.h"
37 namespace shark{
38 
/// \brief A Markov chain that holds a batch of samples which are evolved in parallel.
///
/// The chain is run for a number of sampling steps by applying the transition operator
/// given as template parameter.
42 template<class Operator>
44 private:
45  typedef typename Operator::HiddenSample HiddenSample;
46  typedef typename Operator::VisibleSample VisibleSample;
47 public:
48 
49  ///\brief The MarkovChain can be used to compute several samples at once.
50  static const bool computesBatch = true;
51 
52  ///\brief The type of the RBM the operator is working with.
53  typedef typename Operator::RBM RBM;
54  ///\brief A batch of samples containing hidden and visible samples as well as the energies.
56 
57  ///\brief Mutable reference to an element of the batch.
58  typedef typename SampleBatch::reference reference;
59 
60  ///\brief Immutable reference to an element of the batch.
61  typedef typename SampleBatch::const_reference const_reference;
62 private:
63  ///\brief The batch of samples containing the state of the visible and the hidden units.
64  SampleBatch m_samples;
65  ///\brief The transition operator.
66  Operator m_operator;
67 public:
68 
69  /// \brief Constructor.
70  MarkovChain(RBM* rbm):m_operator(rbm){}
71 
72 
73  /// \brief Sets the number of parallel samples to be evaluated
74  void setBatchSize(std::size_t batchSize){
75  std::size_t visibles=m_operator.rbm()->numberOfVN();
76  std::size_t hiddens=m_operator.rbm()->numberOfHN();
77  m_samples=SampleBatch(batchSize,visibles,hiddens);
78  }
79  std::size_t batchSize(){
80  return m_samples.size();
81  }
82 
83  /// \brief Initializes with data points drawn uniform from the set.
84  ///
85  /// @param dataSet the data set
86  void initializeChain(Data<RealVector> const& dataSet){
87  DiscreteUniform<typename RBM::RngType> uni(m_operator.rbm()->rng(),0,dataSet.numberOfElements()-1);
88  std::size_t visibles=m_operator.rbm()->numberOfVN();
89  RealMatrix sampleData(m_samples.size(),visibles);
90 
91  for(std::size_t i = 0; i != m_samples.size(); ++i){
92  noalias(row(sampleData,i)) = dataSet.element(uni());
93  }
94  initializeChain(sampleData);
95  }
96 
97  /// \brief Initializes with data points from a batch of points
98  ///
99  /// @param sampleData Data set
100  void initializeChain(RealMatrix const& sampleData){
101  m_operator.createSample(m_samples.hidden,m_samples.visible,sampleData);
102  }
103 
104  /// \brief Runs the chain for a given number of steps.
105  ///
106  /// @param numberOfSteps the number of steps
107  void step(unsigned int numberOfSteps){
108  m_operator.stepVH(m_samples.hidden,m_samples.visible,numberOfSteps,blas::repeat(1.0,batchSize()));
109  }
110 
111  /// \brief Returns the current sample of the Markov chain.
112  const_reference sample()const{
113  return const_reference(m_samples,0);
114  }
115 
116  /// \brief Returns the current batch of samples of the Markov chain.
117  SampleBatch const& samples()const{
118  return m_samples;
119  }
120 
121  /// \brief Returns the current batch of samples of the Markov chain.
122  SampleBatch& samples(){
123  return m_samples;
124  }
125 
126  /// \brief Returns the transition operator of the Markov chain.
127  Operator const& transitionOperator()const{
128  return m_operator;
129  }
130 
131  /// \brief Returns the transition operator of the Markov chain.
132  Operator& transitionOperator(){
133  return m_operator;
134  }
135 };
136 
137 }
138 #endif