Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
include
shark
Algorithms
Trainers
RFTrainer.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief Random Forest Trainer
6
*
7
*
8
*
9
* \author K. N. Hansen, J. Kremer
10
* \date 2011-2012
11
*
12
*
13
* \par Copyright 1995-2015 Shark Development Team
14
*
15
* <BR><HR>
16
* This file is part of Shark.
17
* <http://image.diku.dk/shark/>
18
*
19
* Shark is free software: you can redistribute it and/or modify
20
* it under the terms of the GNU Lesser General Public License as published
21
* by the Free Software Foundation, either version 3 of the License, or
22
* (at your option) any later version.
23
*
24
* Shark is distributed in the hope that it will be useful,
25
* but WITHOUT ANY WARRANTY; without even the implied warranty of
26
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27
* GNU Lesser General Public License for more details.
28
*
29
* You should have received a copy of the GNU Lesser General Public License
30
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
31
*
32
*/
33
//===========================================================================
34
35
36
#ifndef SHARK_ALGORITHMS_TRAINERS_RFTRAINER_H
37
#define SHARK_ALGORITHMS_TRAINERS_RFTRAINER_H
38
39
#include <
shark/Core/DLLSupport.h
>
40
#include <
shark/Algorithms/Trainers/AbstractTrainer.h
>
41
#include <
shark/Models/Trees/RFClassifier.h
>
42
43
#include <boost/unordered_map.hpp>
44
#include <set>
45
46
namespace
shark
{
47
/*!
48
* \brief Random Forest
49
*
50
* Random Forest is an ensemble learner, that builds multiple binary decision trees.
51
* The trees are built using a variant of the CART methodology
52
*
53
* The algorithm used to generate each tree is based on the SPRINT algorithm, as
54
* shown by J. Shafer et al.
55
*
56
* Typically 100+ trees are built, and classification/regression is done by combining
57
* the results generated by each tree. Typically, a majority vote is used in the
58
* classification case, and the mean is used in the regression case
59
*
60
* Each tree is built based on a random subset of the total dataset. Furthermore
61
* at each split, only a random subset of the attributes are investigated for
62
* the best split
63
*
64
* The node impurity is measured by the Gini criterion in the classification
65
* case, and the total sum of squared errors in the regression case
66
*
67
* After growing a maximum sized tree, the tree is added to the ensemble
68
* without pruning.
69
*
70
* For detailed information about Random Forest, see Random Forest
71
* by L. Breiman et al. 2001.
72
*
73
* For detailed information about the SPRINT algorithm, see
74
* SPRINT: A Scalable Parallel Classifier for Data Mining
75
* by J. Shafer et al.
76
*/
77
class
RFTrainer
78
:
public
AbstractTrainer
<RFClassifier, unsigned int>
79
,
public
AbstractTrainer
<RFClassifier>,
80
public
IParameterizable
81
{
82
83
public
:
84
/// Construct and compute feature importances when training or not
85
SHARK_EXPORT_SYMBOL
RFTrainer
(
bool
computeFeatureImportances =
false
,
bool
computeOOBerror =
false
);
86
87
/// \brief From INameable: return the class name.
88
std::string
name
()
const
89
{
return
"RFTrainer"
; }
90
91
/// Train a random forest for classification.
92
SHARK_EXPORT_SYMBOL
void
train
(
RFClassifier
& model,
const
ClassificationDataset
& dataset);
93
94
/// Train a random forest for regression.
95
SHARK_EXPORT_SYMBOL
void
train
(
RFClassifier
& model,
const
RegressionDataset
& dataset);
96
97
/// Set the number of random attributes to investigate at each node.
98
SHARK_EXPORT_SYMBOL
void
setMTry
(std::size_t mtry);
99
100
/// Set the number of trees to grow.
101
SHARK_EXPORT_SYMBOL
void
setNTrees
(std::size_t nTrees);
102
103
/// Controls when a node is considered pure. If set to 1, a node is pure
104
/// when it only consists of a single node.
105
SHARK_EXPORT_SYMBOL
void
setNodeSize
(std::size_t nTrees);
106
107
/// Set the fraction of the original training dataset to use as the
108
/// out of bag sample. The default value is 0.66.
109
SHARK_EXPORT_SYMBOL
void
setOOBratio
(
double
ratio);
110
111
/// Return the parameter vector.
112
RealVector
parameterVector
()
const
113
{
114
RealVector ret(1);
// number of trees
115
init
(ret) << (double)
m_B
;
116
return
ret;
117
}
118
119
/// Set the parameter vector.
120
void
setParameterVector
(RealVector
const
& newParameters)
121
{
122
SHARK_ASSERT
(newParameters.size() ==
numberOfParameters
());
123
setNTrees
((
size_t
) newParameters[0]);
124
}
125
126
protected
:
127
struct
RFAttribute
{
128
double
value
;
129
std::size_t
id
;
130
};
131
132
/// attribute table
133
typedef
std::vector < RFAttribute >
AttributeTable
;
134
/// collecting of attribute tables
135
typedef
std::vector < AttributeTable >
AttributeTables
;
136
137
/// Create attribute tables from a data set, and in the process create a count matrix (cAbove).
138
/// A dataset with m features results in m attribute tables.
139
/// [attribute | class/value | row id ]
140
SHARK_EXPORT_SYMBOL
void
createAttributeTables
(
Data<RealVector>
const
& dataset, AttributeTables& tables);
141
142
/// Create a count matrix as used in the classification case.
143
SHARK_EXPORT_SYMBOL
void
createCountMatrix
(
const
ClassificationDataset
& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove);
144
145
// Split attribute tables into left and right parts.
146
SHARK_EXPORT_SYMBOL
void
splitAttributeTables
(
const
AttributeTables& tables, std::size_t index, std::size_t valIndex, AttributeTables& LAttributeTables, AttributeTables& RAttributeTables);
147
148
/// Build a decision tree for classification
149
SHARK_EXPORT_SYMBOL
CARTClassifier<RealVector>::SplitMatrixType
buildTree
(AttributeTables& tables,
const
ClassificationDataset
& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId);
150
151
/// Builds a decision tree for regression
152
SHARK_EXPORT_SYMBOL
CARTClassifier<RealVector>::SplitMatrixType
buildTree
(AttributeTables& tables,
const
RegressionDataset
& dataset,
const
std::vector<RealVector>& labels, std::size_t nodeId);
153
154
/// comparison function for sorting an attributeTable
155
SHARK_EXPORT_SYMBOL
static
bool
tableSort
(
const
RFAttribute
& v1,
const
RFAttribute
& v2);
156
157
/// Generate a histogram from the count matrix.
158
SHARK_EXPORT_SYMBOL
RealVector
hist
(boost::unordered_map<std::size_t, std::size_t> countMatrix);
159
160
/// Average label over a vector.
161
SHARK_EXPORT_SYMBOL
RealVector
average
(
const
std::vector<RealVector>& labels);
162
163
/// Calculate the Gini impurity of the countMatrix
164
SHARK_EXPORT_SYMBOL
double
gini
(boost::unordered_map<std::size_t, std::size_t>& countMatrix, std::size_t n);
165
166
/// Total Sum Of Squares
167
SHARK_EXPORT_SYMBOL
double
totalSumOfSquares
(std::vector<RealVector>& labels, std::size_t from, std::size_t to,
const
RealVector& sumLabel);
168
169
/// Generate random table indices.
170
SHARK_EXPORT_SYMBOL
void
generateRandomTableIndicies
(std::set<std::size_t>& tableIndicies);
171
172
/// Reset the training to its default parameters.
173
SHARK_EXPORT_SYMBOL
void
setDefaults
();
174
175
/// Number of attributes in the dataset
176
std::size_t
m_inputDimension
;
177
178
/// size of labels
179
std::size_t
m_labelDimension
;
180
181
/// maximum size of the histogram;
182
/// classification case: maximum number of classes
183
unsigned
int
m_maxLabel
;
184
185
/// number of attributes to randomly test at each inner node
186
std::size_t
m_try
;
187
188
/// number of trees in the forest
189
std::size_t
m_B
;
190
191
/// number of samples in the terminal nodes
192
std::size_t
m_nodeSize
;
193
194
/// fraction of the data set used for growing trees
195
/// 0 < m_OOBratio < 1
196
double
m_OOBratio
;
197
198
/// true if the trainer is used for regression, false otherwise.
199
bool
m_regressionLearner
;
200
201
// true if the feature importances should be computed
202
bool
m_computeFeatureImportances
;
203
204
// true if OOB error should be computed
205
bool
m_computeOOBerror
;
206
};
207
}
208
#endif