45 #ifndef SHARK_DATA_WEIGHTED_DATASET_H 46 #define SHARK_DATA_WEIGHTED_DATASET_H 52 template <
class DataContainerT>
53 class BaseWeightedDataset :
public ISerializable
// Class template parameterized on the data container type (DataContainerT);
// derives publicly from ISerializable.
// Alias for this particular instantiation.
56 typedef BaseWeightedDataset<DataContainerT> self_type;
// Element type is taken from the data container; weights are always double
// and are stored in a Data<WeightType> container.
58 typedef typename DataContainerT::element_type DataType;
59 typedef double WeightType;
60 typedef DataContainerT DataContainer;
61 typedef Data<WeightType> WeightContainer;
62 typedef typename DataContainer::IndexSet IndexSet;
// NOTE(review): the WeightedDataPair typedef below is truncated in this view;
// the lines completing it (presumably defining element_type) are not shown.
65 typedef WeightedDataPair<
70 typedef typename Batch<element_type>::type batch_type;
// Range types derived via PairRangeType from the data container's element and
// batch ranges, in mutable and const variants. Several argument lines of these
// typedefs are not visible in this view.
73 typedef typename PairRangeType<
75 typename DataContainer::element_range,
77 >::type element_range;
78 typedef typename PairRangeType<
80 typename DataContainer::const_element_range,
82 >::type const_element_range;
83 typedef typename PairRangeType<
85 typename DataContainer::batch_range,
88 typedef typename PairRangeType<
90 typename DataContainer::const_batch_range,
92 >::type const_batch_range;
// Reference types are whatever boost::range_reference yields for the
// corresponding range type.
95 typedef typename boost::range_reference<batch_range>::type batch_reference;
96 typedef typename boost::range_reference<const_batch_range>::type const_batch_reference;
97 typedef typename boost::range_reference<element_range>::type element_reference;
98 typedef typename boost::range_reference<const_element_range>::type const_element_reference;
/// Read-only range over all elements, built by passing the element ranges of
/// the data container and the weight container to zipPairRange<element_type>.
104 const_element_range elements()
const{
105 return zipPairRange<element_type>(m_data.elements(),m_weights.elements());
/// Mutable counterpart of elements() const: same zipPairRange<element_type>
/// call over the data and weight element ranges, returned as element_range.
111 element_range elements(){
112 return zipPairRange<element_type>(m_data.elements(),m_weights.elements());
/// Read-only range over all batches, built by passing the batch ranges of the
/// data container and the weight container to zipPairRange<batch_type>.
119 const_batch_range batches()
const{
120 return zipPairRange<batch_type>(m_data.batches(),m_weights.batches());
/// Mutable counterpart of batches() const: same zipPairRange<batch_type> call
/// over the data and weight batch ranges, returned as batch_range.
126 batch_range batches(){
127 return zipPairRange<batch_type>(m_data.batches(),m_weights.batches());
/// Number of batches, as reported by the data container (the weight container
/// is not consulted).
131 std::size_t numberOfBatches()
const{
132 return m_data.numberOfBatches();
/// Number of elements, as reported by the data container (the weight container
/// is not consulted).
135 std::size_t numberOfElements()
const{
136 return m_data.numberOfElements();
// Emptiness is decided by the data container alone.
// NOTE(review): the enclosing function signature is not visible in this view.
141 return m_data.empty();
/// Read-only accessor returning a const reference of type DataContainer.
/// NOTE(review): body not visible in this view; presumably returns m_data.
145 DataContainer
const& data()
const{
/// Mutable accessor returning a reference of type DataContainer.
/// NOTE(review): body not visible in this view; presumably returns m_data.
149 DataContainer& data(){
/// Read-only accessor returning a const reference of type WeightContainer.
/// NOTE(review): body not visible in this view; presumably returns m_weights.
154 WeightContainer
const& weights()
const{
/// Mutable accessor returning a reference of type WeightContainer.
/// NOTE(review): body not visible in this view; presumably returns m_weights.
158 WeightContainer& weights(){
/// Default constructor.
/// NOTE(review): initializer list and body are not visible in this view.
165 BaseWeightedDataset()
/// Constructs both the data container and the weight container from the same
/// numBatches argument.
/// \param numBatches value forwarded to the constructors of m_data and m_weights
171 BaseWeightedDataset(std::size_t numBatches)
172 : m_data(numBatches),m_weights(numBatches)
/// Constructs the dataset from a single weighted element.
/// m_data is constructed from (size, element.data, batchSize) and m_weights
/// from (size, element.weight, batchSize).
/// \param size       forwarded as first argument to both container constructors
/// \param element    weighted element whose .data and .weight members are used
/// \param batchSize  forwarded as last argument to both container constructors
182 BaseWeightedDataset(std::size_t
size, element_type
const& element, std::size_t batchSize)
183 : m_data(size,element.data,batchSize)
184 , m_weights(size,element.weight,batchSize)
/// Constructs the dataset from an existing data container and an existing
/// weight container; both are copy-initialized into the members.
/// \param data     data container copied into m_data
/// \param weights  weight container copied into m_weights
191 BaseWeightedDataset(DataContainer
const& data, Data<WeightType>
const& weights)
192 : m_data(data), m_weights(weights)
// The two containers must hold the same number of elements.
194 SHARK_CHECK(data.numberOfElements() == weights.numberOfElements(),
"[ BaseWeightedDataset::WeightedUnlabeledData] number of data and number of weights must agree");
// NOTE(review): the body of this per-batch loop is not visible in this view.
196 for(std::size_t i = 0; i != data.numberOfBatches(); ++i){
/// Constructs the dataset from a data container and one weight value.
/// m_weights is created with as many batches as data has; each weight batch is
/// then assigned a Batch<WeightType>::type constructed from (batchSize, weight),
/// where batchSize is the size of the matching data batch.
/// \param data    data container copied into m_data
/// \param weight  weight value passed to every weight batch constructor
203 BaseWeightedDataset(DataContainer
const& data,
double weight)
204 : m_data(data), m_weights(data.numberOfBatches())
206 for(std::size_t i = 0; i != numberOfBatches(); ++i){
// Size the weight batch after the corresponding data batch.
207 std::size_t batchSize =
boost::size(m_data.batch(i));
208 m_weights.batch(i) = Batch<WeightType>::type(batchSize,weight);
/// Returns an element_reference constructed from the i-th element of the data
/// container and the i-th element of the weight container.
214 element_reference element(std::size_t i){
215 return element_reference(m_data.element(i),m_weights.element(i));
/// Const overload: returns a const_element_reference constructed from the i-th
/// element of the data container and the i-th element of the weight container.
217 const_element_reference element(std::size_t i)
const{
218 return const_element_reference(m_data.element(i),m_weights.element(i));
/// Returns a batch_reference constructed from the i-th batch of the data
/// container and the i-th batch of the weight container.
222 batch_reference batch(std::size_t i){
223 return batch_reference(m_data.batch(i),m_weights.batch(i));
/// Const overload: returns a const_batch_reference constructed from the i-th
/// batch of the data container and the i-th batch of the weight container.
225 const_batch_reference batch(std::size_t i)
const{
226 return const_batch_reference(m_data.batch(i),m_weights.batch(i));
/// Calls makeIndependent() on the weight container and then on the data
/// container. Declared virtual.
244 virtual void makeIndependent(){
245 m_weights.makeIndependent();
246 m_data.makeIndependent();
// Constructs a DiscreteUniform object named uni from Rng::globalRng.
// NOTE(review): the enclosing function and the code that uses `uni` are not
// visible in this view.
251 DiscreteUniform<Rng::rng_type>
uni(Rng::globalRng);
/// Forwards the same (batch, elementIndex) split request to the data container
/// and then to the weight container.
/// \param batch         batch argument passed to both containers' splitBatch
/// \param elementIndex  element argument passed to both containers' splitBatch
255 void splitBatch(std::size_t batch, std::size_t elementIndex){
256 m_data.splitBatch(batch,elementIndex);
257 m_weights.splitBatch(batch,elementIndex);
/// Appends another weighted dataset: other's data container is appended to
/// m_data and other's weight container to m_weights.
/// \param other  dataset of the same type whose members are appended
264 void append(self_type
const& other){
265 m_data.append(other.m_data);
266 m_weights.append(other.m_weights);
/// Passes the same batchSizes range to repartition() of the data container and
/// then of the weight container.
/// \tparam Range      type of the batch-size range
/// \param batchSizes  range forwarded unchanged to both containers
274 template<
class Range>
275 void repartition(
Range const& batchSizes){
276 m_data.repartition(batchSizes);
277 m_weights.repartition(batchSizes);
/// Returns the partitioning reported by the data container (the weight
/// container is not consulted).
284 std::vector<std::size_t> getPartitioning()
const{
285 return m_data.getPartitioning();
/// Friend swap: exchanges the data containers and the weight containers of a
/// and b through unqualified swap calls on each member.
288 friend void swap( self_type& a, self_type& b){
289 swap(a.m_data,b.m_data);
290 swap(a.m_weights,b.m_weights);
// Fills subset's data and weight containers by calling indexedSubset with the
// same `indices` on m_data and m_weights.
// NOTE(review): the enclosing function signature is not visible in this view.
298 m_data.indexedSubset(indices,subset.m_data);
299 m_weights.indexedSubset(indices,subset.m_weights);
/// Writes two datasets: `subset` is filled using `indices`, and `complement`
/// is filled using `comp`, which detail::complement computes from `indices`
/// and the number of batches of the data container.
/// \param indices     index set passed to indexedSubset for `subset`
/// \param subset      output dataset filled with `indices`
/// \param complement  output dataset filled with `comp`
303 void indexedSubset(IndexSet
const& indices, self_type& subset, self_type& complement)
const{
// NOTE(review): the declaration of `comp` is not visible in this view.
305 detail::complement(indices,m_data.numberOfBatches(),comp);
306 m_data.indexedSubset(indices,subset.m_data);
307 m_weights.indexedSubset(indices,subset.m_weights);
308 m_data.indexedSubset(comp,complement.m_data);
309 m_weights.indexedSubset(comp,complement.m_weights);
// Container holding the data.
312 DataContainer m_data;
// Container holding the weights (Data<WeightType>).
313 WeightContainer m_weights;
// NOTE(review): only fragments of this class template are visible in this
// view; the class header, constructor signatures and method signatures are
// missing.
330 template <
class DataT>
// The base class is BaseWeightedDataset instantiated over UnlabeledData<DataT>.
335 typedef detail::BaseWeightedDataset <UnlabeledData<DataT> > base_type;
// Make the base-class data()/weights() accessors usable unqualified here.
337 using base_type::data;
338 using base_type::weights;
341 typedef typename base_type::element_type element_type;
// The initializer lists below belong to constructors whose signatures are not
// visible; each one forwards its arguments unchanged to the base class.
356 : base_type(numBatches)
367 : base_type(size,element,batchSize){}
374 : base_type(data,weights)
379 : base_type(data,weight)
// Builds a new self_type from data().splice(batch) and weights().splice(batch).
398 return self_type(data().splice(batch),weights().splice(batch));
// Delegates to the base-class swap by casting both arguments to base_type&.
402 swap(static_cast<base_type&>(a),static_cast<base_type&>(b));
/// Stream output for WeightedUnlabeledData: each element is written as
/// `<weight> [<data>]` followed by a newline.
/// NOTE(review): the template header and the declarations of `reference` and
/// `elements` are not visible in this view.
408 std::ostream &operator << (std::ostream &stream, const WeightedUnlabeledData<T>& d) {
411 BOOST_FOREACH(reference elem,elements)
412 stream << elem.weight <<
" [" << elem.data<<
"]"<<
"\n";
// Factory fragment taking a data range and a weight range.
// The boost::disable_if<boost::is_arithmetic<WeightRange>, ...> return type
// removes this overload when WeightRange is an arithmetic type.
// NOTE(review): the function name, parameter list, the check condition that
// belongs to the message below, and the rest of the body are not visible in
// this view; the name "createDataFromRange" is taken from the message text.
417 template<
class DataRange,
class WeightRange>
418 typename boost::disable_if<
419 boost::is_arithmetic<WeightRange>,
421 typename boost::range_value<DataRange>::type
425 "[createDataFromRange] number of data points and number of weights must agree");
// Local alias for the value type of the data range.
426 typedef typename boost::range_value<DataRange>::type
Data;
// NOTE(review): only fragments of this class template are visible in this
// view; the class header, constructor signatures and method signatures are
// missing.
455 template <
class InputT,
class LabelT>
// The base class is BaseWeightedDataset instantiated over
// LabeledData<InputT,LabelT>.
460 typedef detail::BaseWeightedDataset <LabeledData<InputT,LabelT> > base_type;
466 typedef typename base_type::element_type element_type;
// Make the base-class data()/weights() accessors usable unqualified here.
468 using base_type::data;
469 using base_type::weights;
// The initializer lists below belong to constructors whose signatures are not
// visible; each one forwards its arguments unchanged to the base class.
483 : base_type(numBatches)
494 : base_type(size,element,batchSize){}
501 : base_type(data,weights)
506 : base_type(data,weight)
// Two accessors (signatures not visible) both return data().labels().
520 return data().labels();
524 return data().labels();
// Builds a new self_type from data().splice(batch) and weights().splice(batch).
537 return self_type(data().splice(batch),weights().splice(batch));
/// Friend swap: delegates to the base-class swap by casting both arguments to
/// base_type&.
540 friend void swap(self_type& a, self_type& b){
541 swap(static_cast<base_type&>(a),static_cast<base_type&>(b));
/// Stream output for WeightedLabeledData: each element is written as
/// `<weight> (<label> [<input>] )` followed by a newline.
/// NOTE(review): the declarations of `reference` and `elements` are not
/// visible in this view.
546 template<
class T,
class U>
547 std::ostream &operator << (std::ostream &stream, const WeightedLabeledData<T, U>& d) {
550 BOOST_FOREACH(reference elem,elements)
551 stream << elem.weight <<
" ("<< elem.data.label <<
" [" << elem.data.input<<
"] )"<<
"\n";
// Accumulates, over all batches of `dataset`, the result of sum() applied to
// the batch's weight member, starting from 0.
// NOTE(review): the function signature and the return statement are not
// visible in this view.
557 template<
class InputType>
559 double weightSum = 0;
560 for(std::size_t i = 0; i != dataset.numberOfBatches(); ++i){
561 weightSum +=
sum(dataset.batch(i).weight);
// Two-template-parameter variant of the weight accumulation: adds up sum() of
// each batch's weight member over all batches of `dataset`, starting from 0.
// NOTE(review): the function signature and the return statement are not
// visible in this view.
565 template<
class InputType,
class LabelType>
567 double weightSum = 0;
568 for(std::size_t i = 0; i != dataset.numberOfBatches(); ++i){
569 weightSum +=
sum(dataset.batch(i).weight);
584 template <
class InputType>
590 template <
class InputType,
class LabelType>
596 template <
class InputType,
class LabelType>
601 template <
class InputType>
607 template<
class InputType,
class LabelType>
/// Factory taking ranges of inputs, labels and weights plus an optional batch
/// size (default 0).
/// The boost::disable_if<boost::is_arithmetic<WeightRange>, ...> return type
/// removes this overload when WeightRange is an arithmetic type.
/// \param inputs     range of inputs
/// \param labels     range of labels
/// \param weights    range of weights
/// \param batchSize  batch size, defaults to 0
/// NOTE(review): parts of the return type and the rest of the body are not
/// visible in this view.
615 template<
class InputRange,
class LabelRange,
class WeightRange>
616 typename boost::disable_if<
617 boost::is_arithmetic<WeightRange>,
619 typename boost::range_value<InputRange>::type,
620 typename boost::range_value<LabelRange>::type
622 >::type
createLabeledDataFromRange(InputRange
const& inputs, LabelRange
const& labels, WeightRange
const& weights, std::size_t batchSize = 0){
// NOTE(review): the two messages below are identical and name
// "createDataFromRange" although this function is createLabeledDataFromRange.
// The check conditions are not visible here; if one of them compares inputs
// against labels, its message should probably mention labels. Confirm against
// the full file before changing.
624 "[createDataFromRange] number of data points and number of weights must agree");
626 "[createDataFromRange] number of data points and number of weights must agree");
// Local aliases for the value types of the input and label ranges.
627 typedef typename boost::range_value<InputRange>::type
InputType;
628 typedef typename boost::range_value<LabelRange>::type LabelType;
// Fragment of a function templated on InputType and LabelType that takes an
// optional bootStrapSize parameter (default 0).
// NOTE(review): the function name, its other parameters, the statement
// executed when bootStrapSize == 0, the construction of bootstrapSet and the
// computation of `index` are not visible in this view.
648 template<
class InputType,
class LabelType>
651 std::size_t bootStrapSize = 0
// A size of 0 is treated as a special case (replacement value not visible).
653 if(bootStrapSize == 0)
// Runs bootStrapSize iterations; each adds 1.0 to the weight of the element at
// position `index` in bootstrapSet.
658 for(std::size_t i = 0; i != bootStrapSize; ++i){
660 bootstrapSet.element(index).weight += 1.0;
// Fragment of a function templated on InputType only that takes an optional
// bootStrapSize parameter (default 0).
// NOTE(review): the function name, its other parameters, the statement
// executed when bootStrapSize == 0, the construction of bootstrapSet and the
// computation of `index` are not visible in this view.
674 template<
class InputType>
677 std::size_t bootStrapSize = 0
// A size of 0 is treated as a special case (replacement value not visible).
679 if(bootStrapSize == 0)
// Runs bootStrapSize iterations; each adds 1.0 to the weight of the element at
// position `index` in bootstrapSet.
684 for(std::size_t i = 0; i != bootStrapSize; ++i){
686 bootstrapSet.element(index).weight += 1.0;