shark::RFTrainer Class Reference

Random Forest.

#include <shark/Algorithms/Trainers/RFTrainer.h>

Classes

struct  RFAttribute
 

Public Member Functions

SHARK_EXPORT_SYMBOL RFTrainer (bool computeFeatureImportances=false, bool computeOOBerror=false)
 Construct the trainer; the flags select whether feature importances and the out-of-bag (OOB) error are computed during training.
 
std::string name () const
 From INameable: return the class name.
 
SHARK_EXPORT_SYMBOL void train (RFClassifier &model, const ClassificationDataset &dataset)
 Train a random forest for classification.
 
SHARK_EXPORT_SYMBOL void train (RFClassifier &model, const RegressionDataset &dataset)
 Train a random forest for regression.
 
SHARK_EXPORT_SYMBOL void setMTry (std::size_t mtry)
 Set the number of random attributes to investigate at each node.
 
SHARK_EXPORT_SYMBOL void setNTrees (std::size_t nTrees)
 Set the number of trees to grow.
 
SHARK_EXPORT_SYMBOL void setNodeSize (std::size_t nTrees)
 
SHARK_EXPORT_SYMBOL void setOOBratio (double ratio)
 
RealVector parameterVector () const
 Return the parameter vector.
 
void setParameterVector (RealVector const &newParameters)
 Set the parameter vector.
 
- Public Member Functions inherited from shark::AbstractTrainer< RFClassifier, unsigned int >
virtual void train (ModelType &model, DatasetType const &dataset)=0
 Core of the Trainer interface.
 
- Public Member Functions inherited from shark::INameable
virtual ~INameable ()
 
- Public Member Functions inherited from shark::ISerializable
virtual ~ISerializable ()
 Virtual d'tor.
 
virtual void read (InArchive &archive)
 Read the component from the supplied archive.
 
virtual void write (OutArchive &archive) const
 Write the component to the supplied archive.
 
void load (InArchive &archive, unsigned int version)
 Versioned loading of components, calls read(...).
 
void save (OutArchive &archive, unsigned int version) const
 Versioned storing of components, calls write(...).
 
 BOOST_SERIALIZATION_SPLIT_MEMBER ()
 
- Public Member Functions inherited from shark::AbstractTrainer< RFClassifier >
virtual void train (ModelType &model, DatasetType const &dataset)=0
 Core of the Trainer interface.
 
- Public Member Functions inherited from shark::IParameterizable
virtual ~IParameterizable ()
 
virtual std::size_t numberOfParameters () const
 Return the number of parameters.
 

Protected Types

typedef std::vector< RFAttribute > AttributeTable
 attribute table
 
typedef std::vector< AttributeTable > AttributeTables
 collection of attribute tables
 

Protected Member Functions

SHARK_EXPORT_SYMBOL void createAttributeTables (Data< RealVector > const &dataset, AttributeTables &tables)
 
SHARK_EXPORT_SYMBOL void createCountMatrix (const ClassificationDataset &dataset, boost::unordered_map< std::size_t, std::size_t > &cAbove)
 Create a count matrix as used in the classification case.
 
SHARK_EXPORT_SYMBOL void splitAttributeTables (const AttributeTables &tables, std::size_t index, std::size_t valIndex, AttributeTables &LAttributeTables, AttributeTables &RAttributeTables)
 
SHARK_EXPORT_SYMBOL CARTClassifier< RealVector >::SplitMatrixType buildTree (AttributeTables &tables, const ClassificationDataset &dataset, boost::unordered_map< std::size_t, std::size_t > &cAbove, std::size_t nodeId)
 Build a decision tree for classification.
 
SHARK_EXPORT_SYMBOL CARTClassifier< RealVector >::SplitMatrixType buildTree (AttributeTables &tables, const RegressionDataset &dataset, const std::vector< RealVector > &labels, std::size_t nodeId)
 Builds a decision tree for regression.
 
SHARK_EXPORT_SYMBOL RealVector hist (boost::unordered_map< std::size_t, std::size_t > countMatrix)
 Generate a histogram from the count matrix.
 
SHARK_EXPORT_SYMBOL RealVector average (const std::vector< RealVector > &labels)
 Average label over a vector.
 
SHARK_EXPORT_SYMBOL double gini (boost::unordered_map< std::size_t, std::size_t > &countMatrix, std::size_t n)
 Calculate the Gini impurity of the countMatrix.
 
SHARK_EXPORT_SYMBOL double totalSumOfSquares (std::vector< RealVector > &labels, std::size_t from, std::size_t to, const RealVector &sumLabel)
 Total Sum Of Squares.
 
SHARK_EXPORT_SYMBOL void generateRandomTableIndicies (std::set< std::size_t > &tableIndicies)
 Generate random table indices.
 
SHARK_EXPORT_SYMBOL void setDefaults ()
 Reset the training to its default parameters.
 

Static Protected Member Functions

static SHARK_EXPORT_SYMBOL bool tableSort (const RFAttribute &v1, const RFAttribute &v2)
 comparison function for sorting an attributeTable
 

Protected Attributes

std::size_t m_inputDimension
 Number of attributes in the dataset.
 
std::size_t m_labelDimension
 size of labels
 
unsigned int m_maxLabel
 
std::size_t m_try
 number of attributes to randomly test at each inner node
 
std::size_t m_B
 number of trees in the forest
 
std::size_t m_nodeSize
 number of samples in the terminal nodes
 
double m_OOBratio
 
bool m_regressionLearner
 true if the trainer is used for regression, false otherwise.
 
bool m_computeFeatureImportances
 
bool m_computeOOBerror
 

Additional Inherited Members

- Public Types inherited from shark::AbstractTrainer< RFClassifier, unsigned int >
typedef RFClassifier ModelType
 
typedef ModelType::InputType InputType
 
typedef unsigned int LabelType
 
typedef LabeledData< InputType, LabelType > DatasetType
 
- Public Types inherited from shark::AbstractTrainer< RFClassifier >
typedef RFClassifier ModelType
 
typedef ModelType::InputType InputType
 
typedef typename RFClassifier::OutputType LabelType
 
typedef LabeledData< InputType, LabelType > DatasetType
 

Detailed Description

Random Forest.

Random Forest is an ensemble learner that builds multiple binary decision trees. The trees are built using a variant of the CART methodology.

The algorithm used to generate each tree is based on the SPRINT algorithm, as described by J. Shafer et al.

Typically 100 or more trees are built, and classification or regression is done by combining the results generated by each tree. Usually a majority vote is used in the classification case, and the mean is used in the regression case.
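
The combination step described above can be sketched in plain C++. This is only an illustration and not Shark's implementation: the function names majorityVote and meanPrediction are hypothetical, and breaking ties towards the smallest label is an arbitrary choice made for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <numeric>
#include <vector>

// Majority vote over the class labels predicted by the individual trees.
// Ties are broken in favour of the smallest label (an assumption of this sketch).
unsigned int majorityVote(const std::vector<unsigned int>& treeVotes) {
    assert(!treeVotes.empty());
    std::map<unsigned int, std::size_t> counts;
    for (unsigned int label : treeVotes) ++counts[label];
    unsigned int best = counts.begin()->first;
    std::size_t bestCount = counts.begin()->second;
    // std::map iterates in ascending label order, so ">" keeps the smallest label on ties.
    for (const auto& entry : counts) {
        if (entry.second > bestCount) {
            best = entry.first;
            bestCount = entry.second;
        }
    }
    return best;
}

// Mean of the per-tree predictions, as used in the regression case.
double meanPrediction(const std::vector<double>& treeOutputs) {
    assert(!treeOutputs.empty());
    const double sum = std::accumulate(treeOutputs.begin(), treeOutputs.end(), 0.0);
    return sum / static_cast<double>(treeOutputs.size());
}
```

For vector-valued regression labels the same mean would be taken component-wise.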

Each tree is built from a random subset of the total dataset. Furthermore, at each split only a random subset of the attributes is investigated for the best split.
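
As a rough illustration of the per-split attribute sampling, the sketch below draws mtry distinct attribute indices into a std::set, the container type that generateRandomTableIndicies() also takes. The function name sampleAttributeIndices, the use of std::mt19937, and the rejection-sampling strategy are assumptions made for this example; the sampling scheme inside RFTrainer may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <set>

// Draw `mtry` distinct attribute indices from {0, ..., inputDimension - 1}.
std::set<std::size_t> sampleAttributeIndices(std::size_t inputDimension,
                                             std::size_t mtry,
                                             std::mt19937& rng) {
    assert(mtry <= inputDimension);
    std::set<std::size_t> indices;
    if (inputDimension == 0) return indices;
    std::uniform_int_distribution<std::size_t> pick(0, inputDimension - 1);
    // Rejection sampling: std::set ignores duplicates, so loop until mtry distinct values exist.
    while (indices.size() < mtry) indices.insert(pick(rng));
    return indices;
}
```

Only the attribute tables whose indices were drawn would then be scanned for the best split at that node.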

The node impurity is measured by the Gini criterion in the classification case, and by the total sum of squared errors in the regression case.
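
For reference, the two impurity measures named above are commonly written as follows. These are the textbook definitions, stated here as an assumption; the exact scaling used inside RFTrainer may differ. The symbols are introduced for this note only: n is the number of samples in a node, n_k the number of samples of class k, y_i the label of sample i, and \bar{y} the mean label in the node.

```latex
\begin{align}
  G &= 1 - \sum_{k} p_k^{2}, & p_k &= \frac{n_k}{n} \\
  \mathrm{TSS} &= \sum_{i=1}^{n} \lVert y_i - \bar{y} \rVert^{2}, & \bar{y} &= \frac{1}{n} \sum_{i=1}^{n} y_i
\end{align}
```

Both quantities are zero for a node whose samples all share the same label.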

Each tree is grown to maximum size and then added to the ensemble without pruning.

For detailed information about Random Forest, see "Random Forests" by L. Breiman, 2001.

For detailed information about the SPRINT algorithm, see "SPRINT: A Scalable Parallel Classifier for Data Mining" by J. Shafer et al.

Definition at line 77 of file RFTrainer.h.

Member Typedef Documentation

§ AttributeTable

typedef std::vector< RFAttribute > shark::RFTrainer::AttributeTable
protected

attribute table

Definition at line 133 of file RFTrainer.h.

§ AttributeTables

typedef std::vector< AttributeTable > shark::RFTrainer::AttributeTables
protected

collection of attribute tables

Definition at line 135 of file RFTrainer.h.

Constructor & Destructor Documentation

§ RFTrainer()

SHARK_EXPORT_SYMBOL shark::RFTrainer::RFTrainer ( bool  computeFeatureImportances = false,
bool  computeOOBerror = false 
)

Construct the trainer. The two flags select whether feature importances and the out-of-bag (OOB) error are computed during training; both default to false.

Member Function Documentation

§ average()

SHARK_EXPORT_SYMBOL RealVector shark::RFTrainer::average ( const std::vector< RealVector > &  labels)
protected

Average label over a vector.

§ buildTree() [1/2]

SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::SplitMatrixType shark::RFTrainer::buildTree ( AttributeTables tables,
const ClassificationDataset dataset,
boost::unordered_map< std::size_t, std::size_t > &  cAbove,
std::size_t  nodeId 
)
protected

Build a decision tree for classification.

§ buildTree() [2/2]

SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::SplitMatrixType shark::RFTrainer::buildTree ( AttributeTables tables,
const RegressionDataset dataset,
const std::vector< RealVector > &  labels,
std::size_t  nodeId 
)
protected

Builds a decision tree for regression.

§ createAttributeTables()

SHARK_EXPORT_SYMBOL void shark::RFTrainer::createAttributeTables ( Data< RealVector > const &  dataset,
AttributeTables tables 
)
protected

Create attribute tables from a data set, and in the process create a count matrix (cAbove). A dataset with m features results in m attribute tables, each with the layout [ attribute | class/value | row id ].
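
A minimal sketch of building such per-feature tables, in the SPRINT style of one sorted list per attribute. The struct Attribute is a hypothetical stand-in for RFAttribute (whose fields are not shown on this page), the dataset is simplified to a vector of rows, and the class/value column is left out on the assumption that it can be looked up through the row id.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for RFAttribute: one attribute value plus the id of its row.
struct Attribute {
    double value;
    std::size_t id;
};
using Table = std::vector<Attribute>;
using Tables = std::vector<Table>;

// Build one table per feature and sort each table by attribute value.
Tables buildAttributeTables(const std::vector<std::vector<double>>& rows) {
    Tables tables;
    if (rows.empty()) return tables;
    const std::size_t m = rows.front().size();  // number of features
    tables.resize(m);
    for (std::size_t j = 0; j < m; ++j) {
        tables[j].reserve(rows.size());
        for (std::size_t i = 0; i < rows.size(); ++i)
            tables[j].push_back({rows[i][j], i});
        // Stable sort keeps rows with equal values in their original order.
        std::stable_sort(tables[j].begin(), tables[j].end(),
                         [](const Attribute& a, const Attribute& b) { return a.value < b.value; });
    }
    return tables;
}
```

Sorting each table once up front is what lets candidate split points be evaluated with a single linear scan per attribute.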

§ createCountMatrix()

SHARK_EXPORT_SYMBOL void shark::RFTrainer::createCountMatrix ( const ClassificationDataset dataset,
boost::unordered_map< std::size_t, std::size_t > &  cAbove 
)
protected

Create a count matrix as used in the classification case.

§ generateRandomTableIndicies()

SHARK_EXPORT_SYMBOL void shark::RFTrainer::generateRandomTableIndicies ( std::set< std::size_t > &  tableIndicies)
protected

Generate random table indices.

§ gini()

SHARK_EXPORT_SYMBOL double shark::RFTrainer::gini ( boost::unordered_map< std::size_t, std::size_t > &  countMatrix,
std::size_t  n 
)
protected

Calculate the Gini impurity of the countMatrix.
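
As an illustration, the textbook Gini impurity can be computed from a label-to-count map as sketched below. The function name giniImpurity, the use of std::unordered_map in place of boost::unordered_map, and the convention of returning 0 for an empty node are assumptions of this sketch, not a description of Shark's code.

```cpp
#include <cstddef>
#include <unordered_map>

// Gini impurity 1 - sum_k (n_k / n)^2 of a node holding n samples,
// where countMatrix maps a class label to the number of samples of that class.
double giniImpurity(const std::unordered_map<std::size_t, std::size_t>& countMatrix,
                    std::size_t n) {
    if (n == 0) return 0.0;  // convention chosen here for an empty node
    double sumSquares = 0.0;
    for (const auto& entry : countMatrix) {
        const double p = static_cast<double>(entry.second) / static_cast<double>(n);
        sumSquares += p * p;
    }
    return 1.0 - sumSquares;
}
```

A node with two equally frequent classes gives 0.5, and a node containing a single class gives 0.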

§ hist()

SHARK_EXPORT_SYMBOL RealVector shark::RFTrainer::hist ( boost::unordered_map< std::size_t, std::size_t >  countMatrix)
protected

Generate a histogram from the count matrix.
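
One plausible reading of this step is sketched below: the sparse label counts are expanded into a dense vector with one entry per class. Normalising the entries to relative frequencies is an assumption of this sketch, as are the function name histogram and the parameter numClasses, which plays the role of m_maxLabel.

```cpp
#include <cstddef>
#include <unordered_map>
#include <vector>

// Expand class counts into a dense histogram of relative frequencies.
// Labels >= numClasses are ignored in this sketch.
std::vector<double> histogram(const std::unordered_map<std::size_t, std::size_t>& countMatrix,
                              std::size_t numClasses) {
    std::vector<double> hist(numClasses, 0.0);
    std::size_t total = 0;
    for (const auto& entry : countMatrix) {
        if (entry.first >= numClasses) continue;
        hist[entry.first] = static_cast<double>(entry.second);
        total += entry.second;
    }
    // Normalise only when there is at least one counted sample.
    if (total > 0)
        for (double& h : hist) h /= static_cast<double>(total);
    return hist;
}
```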

§ name()

std::string shark::RFTrainer::name ( ) const
inlinevirtual

From INameable: return the class name.

Reimplemented from shark::INameable.

Definition at line 88 of file RFTrainer.h.

References setMTry(), setNodeSize(), setNTrees(), setOOBratio(), SHARK_EXPORT_SYMBOL, and train().

Referenced by main().

§ parameterVector()

RealVector shark::RFTrainer::parameterVector ( ) const
inlinevirtual

Return the parameter vector.

Reimplemented from shark::IParameterizable.

Definition at line 112 of file RFTrainer.h.

References shark::blas::init(), and m_B.

§ setDefaults()

SHARK_EXPORT_SYMBOL void shark::RFTrainer::setDefaults ( )
protected

Reset the training to its default parameters.

§ setMTry()

SHARK_EXPORT_SYMBOL void shark::RFTrainer::setMTry ( std::size_t  mtry)

Set the number of random attributes to investigate at each node.

Referenced by name().

§ setNodeSize()

SHARK_EXPORT_SYMBOL void shark::RFTrainer::setNodeSize ( std::size_t  nTrees)

Controls when a node is considered pure. If set to 1, a node is pure when it consists of only a single sample.

Referenced by name().

§ setNTrees()

SHARK_EXPORT_SYMBOL void shark::RFTrainer::setNTrees ( std::size_t  nTrees)

Set the number of trees to grow.

Referenced by name(), and setParameterVector().

§ setOOBratio()

SHARK_EXPORT_SYMBOL void shark::RFTrainer::setOOBratio ( double  ratio)

Set the fraction of the original training dataset to use as the out-of-bag sample. The default value is 0.66.

Referenced by name().

§ setParameterVector()

void shark::RFTrainer::setParameterVector ( RealVector const &  newParameters)
inlinevirtual

Set the parameter vector.

Reimplemented from shark::IParameterizable.

Definition at line 120 of file RFTrainer.h.

References shark::IParameterizable::numberOfParameters(), setNTrees(), and SHARK_ASSERT.

§ splitAttributeTables()

SHARK_EXPORT_SYMBOL void shark::RFTrainer::splitAttributeTables ( const AttributeTables tables,
std::size_t  index,
std::size_t  valIndex,
AttributeTables LAttributeTables,
AttributeTables RAttributeTables 
)
protected

§ tableSort()

static SHARK_EXPORT_SYMBOL bool shark::RFTrainer::tableSort ( const RFAttribute v1,
const RFAttribute v2 
)
staticprotected

comparison function for sorting an attributeTable

§ totalSumOfSquares()

SHARK_EXPORT_SYMBOL double shark::RFTrainer::totalSumOfSquares ( std::vector< RealVector > &  labels,
std::size_t  from,
std::size_t  to,
const RealVector &  sumLabel 
)
protected

Total Sum Of Squares.
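
A minimal sketch of this quantity over a range of labels, assuming the half-open index range [from, to) and the textbook definition (sum of squared Euclidean distances to the mean label). Unlike the member function, which receives a precomputed sumLabel, this hypothetical free function computes the mean itself so that the example is self-contained.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Label = std::vector<double>;

// Sum of squared Euclidean distances of labels[from..to) to their mean.
double totalSumOfSquares(const std::vector<Label>& labels, std::size_t from, std::size_t to) {
    assert(from < to && to <= labels.size());
    const std::size_t dim = labels[from].size();
    const double n = static_cast<double>(to - from);
    // Component-wise mean of the labels in the range.
    Label mean(dim, 0.0);
    for (std::size_t i = from; i < to; ++i)
        for (std::size_t d = 0; d < dim; ++d) mean[d] += labels[i][d];
    for (double& m : mean) m /= n;
    // Accumulate squared deviations from the mean.
    double tss = 0.0;
    for (std::size_t i = from; i < to; ++i)
        for (std::size_t d = 0; d < dim; ++d) {
            const double diff = labels[i][d] - mean[d];
            tss += diff * diff;
        }
    return tss;
}
```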

§ train() [1/2]

SHARK_EXPORT_SYMBOL void shark::RFTrainer::train ( RFClassifier model,
const ClassificationDataset dataset 
)

Train a random forest for classification.

Referenced by main(), and name().

§ train() [2/2]

SHARK_EXPORT_SYMBOL void shark::RFTrainer::train ( RFClassifier model,
const RegressionDataset dataset 
)

Train a random forest for regression.

Member Data Documentation

§ m_B

std::size_t shark::RFTrainer::m_B
protected

number of trees in the forest

Definition at line 189 of file RFTrainer.h.

Referenced by parameterVector().

§ m_computeFeatureImportances

bool shark::RFTrainer::m_computeFeatureImportances
protected

Definition at line 202 of file RFTrainer.h.

§ m_computeOOBerror

bool shark::RFTrainer::m_computeOOBerror
protected

Definition at line 205 of file RFTrainer.h.

§ m_inputDimension

std::size_t shark::RFTrainer::m_inputDimension
protected

Number of attributes in the dataset.

Definition at line 176 of file RFTrainer.h.

§ m_labelDimension

std::size_t shark::RFTrainer::m_labelDimension
protected

size of labels

Definition at line 179 of file RFTrainer.h.

§ m_maxLabel

unsigned int shark::RFTrainer::m_maxLabel
protected

maximum size of the histogram; classification case: maximum number of classes

Definition at line 183 of file RFTrainer.h.

§ m_nodeSize

std::size_t shark::RFTrainer::m_nodeSize
protected

number of samples in the terminal nodes

Definition at line 192 of file RFTrainer.h.

§ m_OOBratio

double shark::RFTrainer::m_OOBratio
protected

fraction of the data set used for growing trees; must satisfy 0 < m_OOBratio < 1

Definition at line 196 of file RFTrainer.h.
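
To illustrate how such a ratio could partition the data for one tree, the sketch below shuffles the row indices and keeps the first ratio * numRows of them for growing the tree, leaving the rest as the out-of-bag sample. Interpreting the ratio as the in-bag share follows the member description above and is an assumption; the function name splitInBagOutOfBag and the use of std::mt19937 are likewise hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Shuffle the row indices and split them into (in-bag, out-of-bag).
std::pair<std::vector<std::size_t>, std::vector<std::size_t>>
splitInBagOutOfBag(std::size_t numRows, double ratio, std::mt19937& rng) {
    assert(ratio > 0.0 && ratio < 1.0);
    std::vector<std::size_t> indices(numRows);
    std::iota(indices.begin(), indices.end(), std::size_t{0});
    std::shuffle(indices.begin(), indices.end(), rng);
    // Truncate towards zero, so the in-bag part never exceeds ratio * numRows.
    const auto inBagSize = static_cast<std::ptrdiff_t>(static_cast<double>(numRows) * ratio);
    std::vector<std::size_t> inBag(indices.begin(), indices.begin() + inBagSize);
    std::vector<std::size_t> outOfBag(indices.begin() + inBagSize, indices.end());
    return {inBag, outOfBag};
}
```

The out-of-bag rows are the ones a tree has not seen, which is what makes them usable for an error estimate.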

§ m_regressionLearner

bool shark::RFTrainer::m_regressionLearner
protected

true if the trainer is used for regression, false otherwise.

Definition at line 199 of file RFTrainer.h.

§ m_try

std::size_t shark::RFTrainer::m_try
protected

number of attributes to randomly test at each inner node

Definition at line 186 of file RFTrainer.h.


The documentation for this class was generated from the following file: