30 #ifndef SHARK_UNSUPERVISED_RBM_GRADIENTAPPROXIMATIONS_CONTRASTIVEDIVERGENCE_H 31 #define SHARK_UNSUPERVISED_RBM_GRADIENTAPPROXIMATIONS_CONTRASTIVEDIVERGENCE_H 43 template<
class Operator>
46 typedef typename Operator::RBM
RBM;
// NOTE(review): m_regularizationStrength is not initialized in this list. It is
// assigned together with m_regularizer, so it is presumably only read once a
// regularizer has been set. Consider initializing it (e.g. to 0.0) anyway, so it
// never holds an indeterminate value.
52 : mpe_rbm(rbm),m_operator(rbm)
53 , m_k(1), m_numBatches(0),m_regularizer(0){
63 {
return "ContrastiveDivergence"; }
80 return mpe_rbm->parameterVector();
87 return mpe_rbm->numberOfParameters();
105 m_regularizer = regularizer;
106 m_regularizationStrength = factor;
// Fragment of the derivative evaluation (the signature is not visible here).
// It loads `parameter` into the RBM, sizes `derivative`, and then accumulates a
// contrastive-divergence gradient estimate over a randomly ordered selection of
// data batches.
114 mpe_rbm->setParameterVector(parameter);
115 derivative.resize(mpe_rbm->numberOfParameters());
// m_numBatches == 0 means "use every batch of m_data".
// NOTE(review): batchIds[i] is indexed below for every i < batchesForTraining.
// If m_numBatches can exceed m_data.numberOfBatches(), this reads past the end.
// TODO confirm batchIds' size (its declaration is not visible here), or clamp
// with std::min.
118 std::size_t batchesForTraining = m_numBatches > 0? m_numBatches: m_data.
numberOfBatches();
119 std::size_t elements = 0;
// NOTE(review): std::random_shuffle is deprecated in C++14 and removed in C++17.
// std::shuffle requires a uniform random bit generator, so check what `uni` is
// before migrating.
127 std::random_shuffle(batchIds.begin(),batchIds.end(),
uni);
// Count the samples contained in the selected batches. This total is the
// normalizer for the per-thread weights computed further down.
128 for(std::size_t i = 0; i != batchesForTraining; ++i){
129 elements += m_data.
batch(batchIds[i]).size1();
// Split the selected batches evenly across at most SHARK_NUM_THREADS workers.
// NOTE(review): if batchesForTraining is 0 (no batches), threads is 0 and the
// division below divides by zero. `elements` would then also be 0 at the
// weight computation.
133 std::size_t threads = std::min<std::size_t>(batchesForTraining,
SHARK_NUM_THREADS);
134 std::size_t
numBatches = batchesForTraining/threads;
141 std::size_t threadElements = 0;
// The last worker also takes the remainder left over by the integer division
// above. (Assumes `t` is the signed worker index; its declaration is not
// visible here.)
144 std::size_t batchEnd = (t== (int)threads-1)? batchesForTraining : batchStart+
numBatches;
145 for(std::size_t i = batchStart; i != batchEnd; ++i){
146 RealMatrix
const& batch = m_data.
batch(batchIds[i]);
147 threadElements += batch.size1();
// Data phase: clamp the visible units to the batch, sample the hidden units,
// and record the resulting statistics in empiricalAverage.
150 typename Operator::HiddenSampleBatch hiddenBatch(batch.size1(),mpe_rbm->numberOfHN());
151 typename Operator::VisibleSampleBatch visibleBatch(batch.size1(),mpe_rbm->numberOfVN());
153 visibleBatch.state = batch;
154 m_operator.precomputeHidden(hiddenBatch,visibleBatch,
blas::repeat(1.0,batch.size1()));
155 m_operator.sampleHidden(hiddenBatch);
156 empiricalAverage.addVH(hiddenBatch,visibleBatch);
// Run m_k alternating sampling steps (visible given hidden, then hidden given
// visible), starting from the data-clamped state.
158 for(std::size_t step = 0; step != m_k; ++step){
159 m_operator.precomputeVisible(hiddenBatch, visibleBatch,
blas::repeat(1.0,batch.size1()));
160 m_operator.sampleVisible(visibleBatch);
161 m_operator.precomputeHidden(hiddenBatch, visibleBatch,
blas::repeat(1.0,batch.size1()));
163 m_operator.sampleHidden(hiddenBatch);
// Record the statistics of the sampled (model) state.
166 modelAverage.addVH(hiddenBatch,visibleBatch);
// Weight this worker's contribution by its share of the selected samples.
// The gradient estimate is the model average minus the empirical average.
169 double weight = threadElements/double(elements);
170 noalias(derivative) += weight*(modelAverage.result() - empiricalAverage.result());
// Add the scaled regularizer derivative. Presumably this is guarded by a
// non-null m_regularizer check; the condition is not visible here.
178 noalias(derivative) += m_regularizationStrength*regularizerDerivative;
// No objective value is computed; NaN is returned as the function value while
// `derivative` carries the result.
181 return std::numeric_limits<double>::quiet_NaN();
// Number of batches used per derivative evaluation; 0 selects all batches of m_data.
189 std::size_t m_numBatches;
// Factor that scales the regularizer's derivative before it is added to the gradient.
192 double m_regularizationStrength;