Random Forest.
#include <shark/Algorithms/Trainers/RFTrainer.h>
Classes
struct | RFAttribute
Public Member Functions
SHARK_EXPORT_SYMBOL | RFTrainer (bool computeFeatureImportances=false, bool computeOOBerror=false)
Construct the trainer and choose whether feature importances and the out-of-bag error are computed during training.
std::string | name () const
From INameable: return the class name.
SHARK_EXPORT_SYMBOL void | train (RFClassifier &model, const ClassificationDataset &dataset)
Train a random forest for classification.
SHARK_EXPORT_SYMBOL void | train (RFClassifier &model, const RegressionDataset &dataset)
Train a random forest for regression.
SHARK_EXPORT_SYMBOL void | setMTry (std::size_t mtry)
Set the number of random attributes to investigate at each node.
SHARK_EXPORT_SYMBOL void | setNTrees (std::size_t nTrees)
Set the number of trees to grow.
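SHARK_EXPORT_SYMBOL void | setNodeSize (std::size_t nTrees)
Control when a node is considered pure.
SHARK_EXPORT_SYMBOL void | setOOBratio (double ratio)
Set the ratio that splits the training data into the sample used for growing a tree and the out-of-bag sample.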
RealVector | parameterVector () const
Return the parameter vector.
void | setParameterVector (RealVector const &newParameters)
Set the parameter vector.
virtual void | train (ModelType &model, DatasetType const &dataset)=0
Core of the Trainer interface.
virtual | ~INameable ()
virtual | ~ISerializable ()
Virtual destructor.
virtual void | read (InArchive &archive)
Read the component from the supplied archive.
virtual void | write (OutArchive &archive) const
Write the component to the supplied archive.
void | load (InArchive &archive, unsigned int version)
Versioned loading of components, calls read(...).
void | save (OutArchive &archive, unsigned int version) const
Versioned storing of components, calls write(...).
BOOST_SERIALIZATION_SPLIT_MEMBER ()
virtual | ~IParameterizable ()
virtual std::size_t | numberOfParameters () const
Return the number of parameters.
Protected Types
typedef std::vector< RFAttribute > | AttributeTable
Attribute table.
typedef std::vector< AttributeTable > | AttributeTables
Collection of attribute tables.
Protected Member Functions
SHARK_EXPORT_SYMBOL void | createAttributeTables (Data< RealVector > const &dataset, AttributeTables &tables)
Create attribute tables from a dataset.
SHARK_EXPORT_SYMBOL void | createCountMatrix (const ClassificationDataset &dataset, boost::unordered_map< std::size_t, std::size_t > &cAbove)
Create a count matrix as used in the classification case.
SHARK_EXPORT_SYMBOL void | splitAttributeTables (const AttributeTables &tables, std::size_t index, std::size_t valIndex, AttributeTables &LAttributeTables, AttributeTables &RAttributeTables)
Split the attribute tables into left and right attribute tables.
SHARK_EXPORT_SYMBOL CARTClassifier< RealVector >::SplitMatrixType | buildTree (AttributeTables &tables, const ClassificationDataset &dataset, boost::unordered_map< std::size_t, std::size_t > &cAbove, std::size_t nodeId)
Build a decision tree for classification.
SHARK_EXPORT_SYMBOL CARTClassifier< RealVector >::SplitMatrixType | buildTree (AttributeTables &tables, const RegressionDataset &dataset, const std::vector< RealVector > &labels, std::size_t nodeId)
Build a decision tree for regression.
SHARK_EXPORT_SYMBOL RealVector | hist (boost::unordered_map< std::size_t, std::size_t > countMatrix)
Generate a histogram from the count matrix.
SHARK_EXPORT_SYMBOL RealVector | average (const std::vector< RealVector > &labels)
Compute the average label over a vector of labels.
SHARK_EXPORT_SYMBOL double | gini (boost::unordered_map< std::size_t, std::size_t > &countMatrix, std::size_t n)
Calculate the Gini impurity of the countMatrix.
SHARK_EXPORT_SYMBOL double | totalSumOfSquares (std::vector< RealVector > &labels, std::size_t from, std::size_t to, const RealVector &sumLabel)
Compute the total sum of squares.
SHARK_EXPORT_SYMBOL void | generateRandomTableIndicies (std::set< std::size_t > &tableIndicies)
Generate random table indices.
SHARK_EXPORT_SYMBOL void | setDefaults ()
Reset the trainer to its default parameters.
Static Protected Member Functions
static SHARK_EXPORT_SYMBOL bool | tableSort (const RFAttribute &v1, const RFAttribute &v2)
Comparison function for sorting an attribute table.
Protected Attributes
std::size_t | m_inputDimension
Number of attributes in the dataset.
std::size_t | m_labelDimension
Dimension of the labels.
unsigned int | m_maxLabel
Maximum size of the histogram; in the classification case, the maximum number of classes.
std::size_t | m_try
Number of attributes to randomly test at each inner node.
std::size_t | m_B
Number of trees in the forest.
std::size_t | m_nodeSize
Number of samples in the terminal nodes.
double | m_OOBratio
Fraction of the dataset used for growing trees (0 < m_OOBratio < 1).
bool | m_regressionLearner
True if the trainer is used for regression, false otherwise.
bool | m_computeFeatureImportances
True if feature importances are computed during training.
bool | m_computeOOBerror
True if the out-of-bag error is computed during training.
Additional Inherited Members | |
Inherited from the classification trainer interface:
typedef RFClassifier | ModelType |
typedef ModelType::InputType | InputType |
typedef unsigned int | LabelType |
typedef LabeledData< InputType, LabelType > | DatasetType |
Inherited from the regression trainer interface:
typedef RFClassifier | ModelType |
typedef ModelType::InputType | InputType |
typedef typename RFClassifier ::OutputType | LabelType |
typedef LabeledData< InputType, LabelType > | DatasetType |
Random Forest.
Random Forest is an ensemble learner that builds multiple binary decision trees. The trees are built using a variant of the CART methodology.
The algorithm used to generate each tree is based on the SPRINT algorithm, as described by J. Shafer et al.
Typically 100+ trees are built, and classification/regression is done by combining the results generated by each tree. Typically, a majority vote is used in the classification case, and the mean is used in the regression case.
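A minimal sketch of the classification case in plain C++ (no Shark types): each tree contributes one predicted class index and the forest returns the most frequent one. The name `majorityVote` and the tie rule (the smallest class index wins) are assumptions for illustration, not RFClassifier's actual code; in the regression case the per-tree outputs would be averaged instead.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Combine per-tree class predictions by majority vote.
// Classes are 0..numClasses-1; on a tie the smallest class index wins.
unsigned int majorityVote(const std::vector<unsigned int>& treeVotes, unsigned int numClasses) {
    assert(!treeVotes.empty() && numClasses > 0);
    std::vector<std::size_t> counts(numClasses, 0);
    for (unsigned int c : treeVotes) {
        assert(c < numClasses);
        ++counts[c];
    }
    unsigned int best = 0;
    for (unsigned int c = 1; c < numClasses; ++c)
        if (counts[c] > counts[best]) best = c;  // strict '>' keeps the smaller index on ties
    return best;
}
```

For example, votes `{0, 2, 2, 1, 2}` over three classes yield class 2.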
Each tree is built from a random subset of the total dataset. Furthermore, at each split only a random subset of the attributes is investigated for the best split.
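Drawing the random attribute subset for one split can be sketched as picking `mtry` distinct indices out of the input dimension; the documentation gives this role to `generateRandomTableIndicies`, which fills a `std::set<std::size_t>`. The stand-alone function `randomAttributeIndices`, the use of `std::mt19937`, and the rejection-sampling loop are assumptions, not Shark's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <set>

// Draw mtry distinct attribute (table) indices from {0, ..., inputDim-1}:
// pick uniformly at random and let std::set discard repeated picks.
std::set<std::size_t> randomAttributeIndices(std::size_t inputDim, std::size_t mtry,
                                             std::mt19937& rng) {
    assert(mtry <= inputDim);
    std::set<std::size_t> chosen;
    if (mtry == 0) return chosen;
    std::uniform_int_distribution<std::size_t> pick(0, inputDim - 1);
    while (chosen.size() < mtry) chosen.insert(pick(rng));
    return chosen;
}
```

Rejection sampling is cheap when `mtry` is much smaller than the input dimension; when `mtry` approaches the dimension, a partial shuffle would avoid the repeated draws.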
The node impurity is measured by the Gini criterion in the classification case, and by the total sum of squared errors in the regression case.
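For reference, these are the standard CART forms of both measures, for a node with $n$ samples of which $n_c$ belong to class $c$, and labels $y_i$ with mean $\bar{y}$. The last line scores a split by combining its left ($L$) and right ($R$) children in the usual CART way; the text above does not say exactly how Shark combines the children, so treat that line as an assumption.

$$G = 1 - \sum_{c} \left(\frac{n_c}{n}\right)^2$$

$$\mathrm{TSS} = \sum_{i=1}^{n} \lVert y_i - \bar{y} \rVert^2, \qquad \bar{y} = \frac{1}{n} \sum_{i=1}^{n} y_i$$

$$G_{\mathrm{split}} = \frac{n_L}{n} G_L + \frac{n_R}{n} G_R, \qquad \mathrm{TSS}_{\mathrm{split}} = \mathrm{TSS}_L + \mathrm{TSS}_R$$

For example, a node holding two samples of each of two classes has $G = 1 - (0.5^2 + 0.5^2) = 0.5$, while a pure node has $G = 0$.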
After growing a maximum-sized tree, the tree is added to the ensemble without pruning.
For detailed information about Random Forest, see Random Forests by L. Breiman, 2001.
For detailed information about the SPRINT algorithm, see SPRINT: A Scalable Parallel Classifier for Data Mining by J. Shafer et al.
Definition at line 77 of file RFTrainer.h.
typedef std::vector< RFAttribute > shark::RFTrainer::AttributeTable [protected]
Attribute table.
Definition at line 133 of file RFTrainer.h.
typedef std::vector< AttributeTable > shark::RFTrainer::AttributeTables [protected]
Collection of attribute tables.
Definition at line 135 of file RFTrainer.h.
SHARK_EXPORT_SYMBOL shark::RFTrainer::RFTrainer (bool computeFeatureImportances = false, bool computeOOBerror = false)
Construct the trainer and choose whether feature importances and the out-of-bag error are computed during training.
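SHARK_EXPORT_SYMBOL RealVector shark::RFTrainer::average (const std::vector< RealVector > &labels) [protected]
Compute the average label over a vector of labels.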
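A sketch of the averaging step, with `std::vector<double>` standing in for Shark's `RealVector`; the name `averageLabel` is hypothetical. It returns the component-wise mean of the labels, which in CART-style regression trees is the usual value stored in a leaf.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Component-wise mean of a non-empty set of equally sized label vectors.
std::vector<double> averageLabel(const std::vector<std::vector<double>>& labels) {
    assert(!labels.empty());
    std::vector<double> avg(labels.front().size(), 0.0);
    for (const auto& y : labels) {
        assert(y.size() == avg.size());
        for (std::size_t j = 0; j < y.size(); ++j) avg[j] += y[j];
    }
    for (double& a : avg) a /= static_cast<double>(labels.size());
    return avg;
}
```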
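SHARK_EXPORT_SYMBOL CARTClassifier< RealVector >::SplitMatrixType shark::RFTrainer::buildTree (AttributeTables &tables, const ClassificationDataset &dataset, boost::unordered_map< std::size_t, std::size_t > &cAbove, std::size_t nodeId) [protected]
Build a decision tree for classification.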
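SHARK_EXPORT_SYMBOL CARTClassifier< RealVector >::SplitMatrixType shark::RFTrainer::buildTree (AttributeTables &tables, const RegressionDataset &dataset, const std::vector< RealVector > &labels, std::size_t nodeId) [protected]
Build a decision tree for regression.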
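SHARK_EXPORT_SYMBOL void shark::RFTrainer::createAttributeTables (Data< RealVector > const &dataset, AttributeTables &tables) [protected]
Create attribute tables from a dataset. A dataset with m features results in m attribute tables, each with the columns [attribute | class/value | row id]. The count matrix (cAbove) is built by createCountMatrix().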
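A sketch of building SPRINT-style attribute tables: one table per feature, each holding (value, row id) pairs sorted by value, with the lambda playing the part of `tableSort`. `AttributeEntry` is a hypothetical stand-in for `RFAttribute`, whose fields are not shown here; the class/value column is left out because the row id suffices to look the label up, and nested `std::vector`s replace Shark's `Data<RealVector>`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for RFAttribute: one row of an attribute table.
struct AttributeEntry {
    double value;    // value of this attribute for the sample
    std::size_t id;  // row id of the sample in the dataset
};
using AttributeTable = std::vector<AttributeEntry>;

// Build one attribute table per feature; inputs[i][j] is feature j of sample i.
// Each table is sorted by attribute value (stable, so ties keep row order).
std::vector<AttributeTable> buildAttributeTables(const std::vector<std::vector<double>>& inputs) {
    const std::size_t dim = inputs.empty() ? 0 : inputs.front().size();
    std::vector<AttributeTable> tables(dim);
    for (std::size_t i = 0; i < inputs.size(); ++i) {
        assert(inputs[i].size() == dim);
        for (std::size_t j = 0; j < dim; ++j)
            tables[j].push_back({inputs[i][j], i});
    }
    for (AttributeTable& t : tables)
        std::stable_sort(t.begin(), t.end(),
                         [](const AttributeEntry& a, const AttributeEntry& b) { return a.value < b.value; });
    return tables;
}
```

Sorting each table once up front is what lets a SPRINT-style learner evaluate every threshold of an attribute in a single linear scan.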
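SHARK_EXPORT_SYMBOL void shark::RFTrainer::createCountMatrix (const ClassificationDataset &dataset, boost::unordered_map< std::size_t, std::size_t > &cAbove) [protected]
Create a count matrix as used in the classification case.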
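SHARK_EXPORT_SYMBOL void shark::RFTrainer::generateRandomTableIndicies (std::set< std::size_t > &tableIndicies) [protected]
Generate random table indices.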
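SHARK_EXPORT_SYMBOL double shark::RFTrainer::gini (boost::unordered_map< std::size_t, std::size_t > &countMatrix, std::size_t n) [protected]
Calculate the Gini impurity of the countMatrix.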
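A sketch of the Gini impurity of a count matrix, plus the size-weighted score of a candidate split. `std::map` replaces `boost::unordered_map`, the names `giniImpurity` and `splitImpurity` are hypothetical, and treating an empty node as pure as well as the weighting scheme (standard CART) are assumptions rather than confirmed Shark behaviour.

```cpp
#include <cassert>
#include <cstddef>
#include <map>

// Gini impurity 1 - sum_c (n_c / n)^2 of a node, where countMatrix maps each
// class label c to its count n_c and n is the number of samples in the node.
// An empty node is treated as pure (impurity 0) in this sketch.
double giniImpurity(const std::map<unsigned int, std::size_t>& countMatrix, std::size_t n) {
    if (n == 0) return 0.0;
    double sumSq = 0.0;
    for (const auto& entry : countMatrix) {
        const double p = static_cast<double>(entry.second) / static_cast<double>(n);
        sumSq += p * p;
    }
    return 1.0 - sumSq;
}

// Impurity of a candidate split: the children's Gini values weighted by their sizes.
// Lower is better; 0 means both children are pure.
double splitImpurity(const std::map<unsigned int, std::size_t>& left, std::size_t nLeft,
                     const std::map<unsigned int, std::size_t>& right, std::size_t nRight) {
    const std::size_t n = nLeft + nRight;
    assert(n > 0);
    return (static_cast<double>(nLeft) * giniImpurity(left, nLeft) +
            static_cast<double>(nRight) * giniImpurity(right, nRight)) /
           static_cast<double>(n);
}
```

In a scan over a sorted attribute table, the two count maps can be updated incrementally as each sample moves from the right-hand side (presumably what `cAbove` tracks) to the left-hand side, so every threshold is scored without recounting.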
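SHARK_EXPORT_SYMBOL RealVector shark::RFTrainer::hist (boost::unordered_map< std::size_t, std::size_t > countMatrix) [protected]
Generate a histogram from the count matrix.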
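A sketch of a count matrix (class label to sample count, the kind of map `createCountMatrix` fills) and the histogram derived from it, using `std::map` in place of `boost::unordered_map`. The histogram has `maxLabel + 1` entries, matching the role of `m_maxLabel` as the maximum histogram size, and is normalized to class fractions; the normalization and the names `countClasses` and `histogram` are assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Count matrix: class label -> number of samples carrying that label.
std::map<unsigned int, std::size_t> countClasses(const std::vector<unsigned int>& labels) {
    std::map<unsigned int, std::size_t> counts;
    for (unsigned int c : labels) ++counts[c];
    return counts;
}

// Histogram over classes 0..maxLabel: entry c is the fraction of counted
// samples with label c (all zeros if the count matrix is empty).
std::vector<double> histogram(const std::map<unsigned int, std::size_t>& countMatrix,
                              unsigned int maxLabel) {
    std::vector<double> hist(static_cast<std::size_t>(maxLabel) + 1, 0.0);
    std::size_t total = 0;
    for (const auto& entry : countMatrix) {
        assert(entry.first <= maxLabel);
        hist[entry.first] = static_cast<double>(entry.second);
        total += entry.second;
    }
    if (total > 0)
        for (double& h : hist) h /= static_cast<double>(total);
    return hist;
}
```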
std::string shark::RFTrainer::name () const [inline, virtual]
From INameable: return the class name.
Reimplemented from shark::INameable.
Definition at line 88 of file RFTrainer.h.
References setMTry(), setNodeSize(), setNTrees(), setOOBratio(), SHARK_EXPORT_SYMBOL, and train().
Referenced by main().
RealVector shark::RFTrainer::parameterVector () const [inline, virtual]
Return the parameter vector.
Reimplemented from shark::IParameterizable.
Definition at line 112 of file RFTrainer.h.
References shark::blas::init(), and m_B.
SHARK_EXPORT_SYMBOL void shark::RFTrainer::setDefaults () [protected]
Reset the trainer to its default parameters.
SHARK_EXPORT_SYMBOL void shark::RFTrainer::setMTry (std::size_t mtry)
Set the number of random attributes to investigate at each node.
Referenced by name().
SHARK_EXPORT_SYMBOL void shark::RFTrainer::setNodeSize (std::size_t nTrees)
Controls when a node is considered pure. If set to 1, a node is considered pure only when it consists of a single sample.
Referenced by name().
SHARK_EXPORT_SYMBOL void shark::RFTrainer::setNTrees (std::size_t nTrees)
Set the number of trees to grow.
Referenced by name(), and setParameterVector().
SHARK_EXPORT_SYMBOL void shark::RFTrainer::setOOBratio (double ratio)
Set the fraction of the original training dataset used for growing each tree; the remaining samples form the out-of-bag sample. The default value is 0.66.
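A sketch of splitting the training rows for one tree, assuming (as the m_OOBratio member documentation states) that the ratio is the fraction used for growing the tree and the rest is out of bag. The name `splitInBagOOB` and the use of `std::mt19937` are hypothetical. This draws without replacement, as a ratio-based split suggests; Breiman's original method instead uses bootstrap samples drawn with replacement.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Randomly split the row indices 0..n-1: after shuffling, the first
// floor(ratio * n) indices form the in-bag sample used to grow one tree,
// and the remaining indices form its out-of-bag (OOB) sample.
void splitInBagOOB(std::size_t n, double ratio, std::mt19937& rng,
                   std::vector<std::size_t>& inBag, std::vector<std::size_t>& oob) {
    assert(ratio > 0.0 && ratio < 1.0);
    std::vector<std::size_t> idx(n);
    std::iota(idx.begin(), idx.end(), std::size_t{0});
    std::shuffle(idx.begin(), idx.end(), rng);
    const std::size_t k = static_cast<std::size_t>(ratio * static_cast<double>(n));
    inBag.assign(idx.begin(), idx.begin() + static_cast<std::ptrdiff_t>(k));
    oob.assign(idx.begin() + static_cast<std::ptrdiff_t>(k), idx.end());
}
```

With the default of 0.66 and 10 rows, 6 rows would grow the tree and 4 would be held out.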
Referenced by name().
void shark::RFTrainer::setParameterVector (RealVector const &newParameters) [inline, virtual]
Set the parameter vector.
Reimplemented from shark::IParameterizable.
Definition at line 120 of file RFTrainer.h.
References shark::IParameterizable::numberOfParameters(), setNTrees(), and SHARK_ASSERT.
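SHARK_EXPORT_SYMBOL void shark::RFTrainer::splitAttributeTables (const AttributeTables &tables, std::size_t index, std::size_t valIndex, AttributeTables &LAttributeTables, AttributeTables &RAttributeTables) [protected]
Split the attribute tables into left and right attribute tables.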
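A sketch of how SPRINT-style attribute tables can be split once the best split is found in table `index`. It assumes that positions `0..valIndex` of that sorted table go to the left child, which is an interpretation of the parameter names rather than confirmed behaviour; `AttributeEntry` and `splitTables` are hypothetical stand-ins for `RFAttribute` and `splitAttributeTables`.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for RFAttribute: one row of an attribute table.
struct AttributeEntry {
    double value;    // attribute value of the sample
    std::size_t id;  // row id of the sample in the dataset
};
using AttributeTable = std::vector<AttributeEntry>;

// Split every attribute table according to a split found in table `index`:
// the rows at positions 0..valIndex of that sorted table go left, all others go right.
// Each table is partitioned by row id in one ordered pass, so every child table
// stays sorted by value and no re-sorting is needed.
void splitTables(const std::vector<AttributeTable>& tables, std::size_t index,
                 std::size_t valIndex, std::vector<AttributeTable>& left,
                 std::vector<AttributeTable>& right) {
    assert(index < tables.size() && valIndex < tables[index].size());
    std::unordered_set<std::size_t> leftIds;
    for (std::size_t k = 0; k <= valIndex; ++k) leftIds.insert(tables[index][k].id);
    left.assign(tables.size(), AttributeTable{});
    right.assign(tables.size(), AttributeTable{});
    for (std::size_t j = 0; j < tables.size(); ++j)
        for (const AttributeEntry& e : tables[j])
            (leftIds.count(e.id) != 0 ? left[j] : right[j]).push_back(e);
}
```

Keeping the child tables sorted is the point of this design: the expensive sort happens once per attribute, and each split only costs a linear pass.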
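static SHARK_EXPORT_SYMBOL bool shark::RFTrainer::tableSort (const RFAttribute &v1, const RFAttribute &v2) [static, protected]
Comparison function for sorting an attribute table.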
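SHARK_EXPORT_SYMBOL double shark::RFTrainer::totalSumOfSquares (std::vector< RealVector > &labels, std::size_t from, std::size_t to, const RealVector &sumLabel) [protected]
Compute the total sum of squares.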
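A sketch of the total sum of squares of the labels in the half-open range `[from, to)` around their mean, following the documented parameter list but with `std::vector<double>` in place of `RealVector`. The assumption is that `sumLabel` is the precomputed component-wise sum of those labels, which lets a split scan update it incrementally instead of recomputing the mean each time.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Total sum of squares of labels[from..to) around their mean,
// where sumLabel is the component-wise sum of exactly those labels.
double totalSumOfSquares(const std::vector<std::vector<double>>& labels, std::size_t from,
                         std::size_t to, const std::vector<double>& sumLabel) {
    assert(from <= to && to <= labels.size());
    if (from == to) return 0.0;
    const double count = static_cast<double>(to - from);
    double tss = 0.0;
    for (std::size_t i = from; i < to; ++i) {
        assert(labels[i].size() == sumLabel.size());
        for (std::size_t j = 0; j < sumLabel.size(); ++j) {
            const double d = labels[i][j] - sumLabel[j] / count;  // deviation from the mean
            tss += d * d;
        }
    }
    return tss;
}
```

For labels 1, 3, 5 the mean is 3 and the total sum of squares is 4 + 0 + 4 = 8.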
SHARK_EXPORT_SYMBOL void shark::RFTrainer::train (RFClassifier &model, const ClassificationDataset &dataset)
Train a random forest for classification.
SHARK_EXPORT_SYMBOL void shark::RFTrainer::train (RFClassifier &model, const RegressionDataset &dataset)
Train a random forest for regression.
std::size_t shark::RFTrainer::m_B [protected]
Number of trees in the forest.
Definition at line 189 of file RFTrainer.h.
Referenced by parameterVector().
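bool shark::RFTrainer::m_computeFeatureImportances [protected]
True if feature importances are computed during training.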
Definition at line 202 of file RFTrainer.h.
bool shark::RFTrainer::m_computeOOBerror [protected]
True if the out-of-bag error is computed during training.
Definition at line 205 of file RFTrainer.h.
std::size_t shark::RFTrainer::m_inputDimension [protected]
Number of attributes in the dataset.
Definition at line 176 of file RFTrainer.h.
std::size_t shark::RFTrainer::m_labelDimension [protected]
Dimension of the labels.
Definition at line 179 of file RFTrainer.h.
unsigned int shark::RFTrainer::m_maxLabel [protected]
Maximum size of the histogram; in the classification case, the maximum number of classes.
Definition at line 183 of file RFTrainer.h.
std::size_t shark::RFTrainer::m_nodeSize [protected]
Number of samples in the terminal nodes.
Definition at line 192 of file RFTrainer.h.
double shark::RFTrainer::m_OOBratio [protected]
Fraction of the dataset used for growing trees (0 < m_OOBratio < 1).
Definition at line 196 of file RFTrainer.h.
bool shark::RFTrainer::m_regressionLearner [protected]
True if the trainer is used for regression, false otherwise.
Definition at line 199 of file RFTrainer.h.
std::size_t shark::RFTrainer::m_try [protected]
Number of attributes to randomly test at each inner node.
Definition at line 186 of file RFTrainer.h.