shark::CARTTrainer Class Reference

Classification And Regression Trees (CART).

#include <shark/Algorithms/Trainers/CARTTrainer.h>

Classes

struct  TableEntry
 Types frequently used.
 

Public Types

typedef CARTClassifier< RealVector > ModelType
 
- Public Types inherited from shark::AbstractTrainer< CARTClassifier< RealVector >, unsigned int >
typedef CARTClassifier< RealVector > ModelType
 
typedef ModelType::InputType InputType
 
typedef unsigned int LabelType
 
typedef LabeledData< InputType, LabelType > DatasetType
 
- Public Types inherited from shark::AbstractTrainer< CARTClassifier< RealVector >, RealVector >
typedef CARTClassifier< RealVector > ModelType
 
typedef ModelType::InputType InputType
 
typedef RealVector LabelType
 
typedef LabeledData< InputType, LabelType > DatasetType
 

Public Member Functions

 CARTTrainer ()
 Constructor.
 
std::string name () const
 From INameable: return the class name.
 
SHARK_EXPORT_SYMBOL void train (ModelType &model, ClassificationDataset const &dataset)
 Train classification.
 
SHARK_EXPORT_SYMBOL void train (ModelType &model, RegressionDataset const &dataset)
 Train regression.
 
void setNumberOfFolds (unsigned int folds)
 Sets the number of folds used for creation of the trees.
 
- Public Member Functions inherited from shark::AbstractTrainer< CARTClassifier< RealVector >, unsigned int >
virtual void train (ModelType &model, DatasetType const &dataset)=0
 Core of the Trainer interface.
 
- Public Member Functions inherited from shark::INameable
virtual ~INameable ()
 
- Public Member Functions inherited from shark::ISerializable
virtual ~ISerializable ()
 Virtual destructor.
 
virtual void read (InArchive &archive)
 Read the component from the supplied archive.
 
virtual void write (OutArchive &archive) const
 Write the component to the supplied archive.
 
void load (InArchive &archive, unsigned int version)
 Versioned loading of components, calls read(...).
 
void save (OutArchive &archive, unsigned int version) const
 Versioned storing of components, calls write(...).
 
 BOOST_SERIALIZATION_SPLIT_MEMBER ()
 
- Public Member Functions inherited from shark::AbstractTrainer< CARTClassifier< RealVector >, RealVector >
virtual void train (ModelType &model, DatasetType const &dataset)=0
 Core of the Trainer interface.
 

Protected Types

typedef std::vector< TableEntry > AttributeTable
 
typedef std::vector< AttributeTable > AttributeTables
 
typedef ModelType::SplitMatrixType SplitMatrixType
 

Protected Member Functions

SHARK_EXPORT_SYMBOL SplitMatrixType buildTree (AttributeTables const &tables, ClassificationDataset const &dataset, boost::unordered_map< std::size_t, std::size_t > &cAbove, std::size_t nodeId)
 
SHARK_EXPORT_SYMBOL double gini (boost::unordered_map< std::size_t, std::size_t > &countMatrix, std::size_t n)
 
SHARK_EXPORT_SYMBOL RealVector hist (boost::unordered_map< std::size_t, std::size_t > countMatrix)
 Creates a histogram from the count matrix.
 
SHARK_EXPORT_SYMBOL SplitMatrixType buildTree (AttributeTables const &tables, RegressionDataset const &dataset, std::vector< RealVector > const &labels, std::size_t nodeId, std::size_t trainSize)
 Regression functions.
 
SHARK_EXPORT_SYMBOL double totalSumOfSquares (std::vector< RealVector > const &labels, std::size_t start, std::size_t length, const RealVector &sumLabel)
 Calculates the total sum of squares.
 
SHARK_EXPORT_SYMBOL RealVector mean (std::vector< RealVector > const &labels)
 Calculates the mean of a vector of labels.
 
SHARK_EXPORT_SYMBOL void pruneMatrix (SplitMatrixType &splitMatrix)
 
SHARK_EXPORT_SYMBOL void pruneNode (SplitMatrixType &splitMatrix, std::size_t nodeId)
 Prunes a single node, including the child nodes of the decision tree.
 
SHARK_EXPORT_SYMBOL void measureStrenght (SplitMatrixType &splitMatrix, std::size_t nodeId, std::size_t parentNodeId)
 Updates the node variables used in the cost complexity pruning stage.
 
SHARK_EXPORT_SYMBOL std::size_t findNode (SplitMatrixType &splitMatrix, std::size_t nodeId)
 Returns the index of the node with node id in splitMatrix.
 
SHARK_EXPORT_SYMBOL AttributeTables createAttributeTables (Data< RealVector > const &dataset)
 
SHARK_EXPORT_SYMBOL void splitAttributeTables (AttributeTables const &tables, std::size_t index, std::size_t valIndex, AttributeTables &LAttributeTables, AttributeTables &RAttributeTables)
 Splits the attribute tables by an attribute index and value. Returns a left and a right attribute table in the variables LAttributeTables and RAttributeTables.
 
SHARK_EXPORT_SYMBOL boost::unordered_map< std::size_t, std::size_t > createCountMatrix (ClassificationDataset const &dataset)
 Creates count matrices from a classification dataset.
 

Protected Attributes

std::size_t m_inputDimension
 Number of attributes in the dataset.
 
std::size_t m_labelDimension
 Size of labels.
 
std::size_t m_nodeSize
 Controls the number of samples in the terminal nodes.
 
unsigned int m_maxLabel
 Holds the maximum label. Used in allocating the histograms.
 
unsigned int m_numberOfFolds
 Number of folds used to create the tree.
 

Detailed Description

Classification And Regression Trees (CART).

CART is a decision tree algorithm that builds a binary decision tree. The tree is built by recursively partitioning the dataset.

At each node, the partition chosen is the one that produces the largest decrease in node impurity.

Node impurity is measured by the Gini criterion in the classification case and by the total sum of squares error in the regression case.
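As an illustration of choosing the partition with the largest impurity decrease, the following is a minimal standalone sketch for one attribute in the classification case. It uses only the C++ standard library; the names giniFromCounts, Split and bestSplit are hypothetical and are not part of the Shark API, and the midpoint threshold is an assumption made for the example.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Gini impurity of a node from per-class counts and the node size n.
double giniFromCounts(std::vector<std::size_t> const& counts, std::size_t n) {
    if (n == 0) return 0.0;
    double sumSq = 0.0;
    for (std::size_t c : counts) {
        double p = static_cast<double>(c) / static_cast<double>(n);
        sumSq += p * p;
    }
    return 1.0 - sumSq;
}

// Hypothetical result type: samples with value <= threshold go left.
struct Split {
    double threshold;
    double decrease;  // impurity decrease achieved by the split
    bool found;       // false if no valid split exists
};

// Scans one attribute, given as (value, label) pairs, and returns the
// threshold that yields the largest decrease in Gini impurity.
Split bestSplit(std::vector<std::pair<double, std::size_t>> samples,
                std::size_t numClasses) {
    std::sort(samples.begin(), samples.end());
    std::size_t n = samples.size();
    std::vector<std::size_t> left(numClasses, 0), right(numClasses, 0);
    for (auto const& s : samples) {
        assert(s.second < numClasses);
        ++right[s.second];
    }
    double parent = giniFromCounts(right, n);
    Split best{0.0, 0.0, false};
    for (std::size_t i = 0; i + 1 < n; ++i) {
        // Move sample i from the right partition to the left one.
        ++left[samples[i].second];
        --right[samples[i].second];
        // Equal neighbouring values cannot be separated by a threshold.
        if (samples[i].first == samples[i + 1].first) continue;
        std::size_t nl = i + 1, nr = n - nl;
        double weighted = (static_cast<double>(nl) * giniFromCounts(left, nl) +
                           static_cast<double>(nr) * giniFromCounts(right, nr)) /
                          static_cast<double>(n);
        double decrease = parent - weighted;
        if (!best.found || decrease > best.decrease) {
            best = Split{0.5 * (samples[i].first + samples[i + 1].first),
                         decrease, true};
        }
    }
    return best;
}
```

For four samples with values 1, 2, 3, 4 and labels 0, 0, 1, 1, the parent impurity is 0.5 and the split between 2 and 3 removes all of it.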

The tree is grown until all leaves are pure. A node is considered pure when it consists only of identical cases in the classification case, and of identical or single values in the regression case.

After the maximum-sized tree is grown, it is pruned back from the leaves upward. The pruning is done by cost complexity pruning, as described by L. Breiman.
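Cost complexity pruning, in the textbook formulation, repeatedly removes the "weakest link": the internal node t with the smallest value of g(t) = (R(t) - R(T_t)) / (|T_t| - 1), where R(t) is the error of the node if it were a leaf, R(T_t) the error of the subtree below it, and |T_t| the number of leaves of that subtree. The sketch below shows only this selection step. The NodeInfo struct and both function names are hypothetical, and the layout is an assumption that does not mirror how the split matrix stores these values.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical per-node bookkeeping for the pruning stage.
struct NodeInfo {
    std::size_t id;
    double nodeError;     // R(t): error if the node were turned into a leaf
    double subtreeError;  // R(T_t): summed error of the leaves below the node
    std::size_t leaves;   // |T_t|: number of leaves below the node
};

// Strength g(t) of an internal node: error increase per removed leaf.
double strength(NodeInfo const& node) {
    assert(node.leaves >= 2);  // internal nodes have at least two leaves
    return (node.nodeError - node.subtreeError) /
           static_cast<double>(node.leaves - 1);
}

// Returns the id of the internal node with the smallest strength.
std::size_t weakestLink(std::vector<NodeInfo> const& internalNodes) {
    assert(!internalNodes.empty());
    auto it = std::min_element(
        internalNodes.begin(), internalNodes.end(),
        [](NodeInfo const& a, NodeInfo const& b) {
            return strength(a) < strength(b);
        });
    return it->id;
}
```

Pruning the returned node and recomputing the strengths of its ancestors yields the next tree in the nested sequence of pruned trees.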

The algorithm used is based on the SPRINT algorithm, as shown by J. Shafer et al.

For more detailed information about CART, see Classification And Regression Trees by L. Breiman et al. (1984).

Definition at line 70 of file CARTTrainer.h.

Member Typedef Documentation

§ AttributeTable

typedef std::vector< TableEntry > shark::CARTTrainer::AttributeTable
protected

Definition at line 108 of file CARTTrainer.h.

§ AttributeTables

typedef std::vector< AttributeTable > shark::CARTTrainer::AttributeTables
protected

Definition at line 109 of file CARTTrainer.h.

§ ModelType

typedef CARTClassifier< RealVector > shark::CARTTrainer::ModelType

Definition at line 75 of file CARTTrainer.h.

§ SplitMatrixType

typedef ModelType::SplitMatrixType shark::CARTTrainer::SplitMatrixType
protected

Constructor & Destructor Documentation

§ CARTTrainer()

shark::CARTTrainer::CARTTrainer ( )
inline

Constructor.

Definition at line 78 of file CARTTrainer.h.

References m_nodeSize, and m_numberOfFolds.

Member Function Documentation

§ buildTree() [1/2]

SHARK_EXPORT_SYMBOL SplitMatrixType shark::CARTTrainer::buildTree ( AttributeTables const &  tables,
ClassificationDataset const &  dataset,
boost::unordered_map< std::size_t, std::size_t > &  cAbove,
std::size_t  nodeId 
)
protected

Builds a single decision tree from a classification dataset. The method requires the attribute tables, together with the dataset, the cAbove count map and the node id listed in the signature.

§ buildTree() [2/2]

SHARK_EXPORT_SYMBOL SplitMatrixType shark::CARTTrainer::buildTree ( AttributeTables const &  tables,
RegressionDataset const &  dataset,
std::vector< RealVector > const &  labels,
std::size_t  nodeId,
std::size_t  trainSize 
)
protected

Regression functions.

§ createAttributeTables()

SHARK_EXPORT_SYMBOL AttributeTables shark::CARTTrainer::createAttributeTables ( Data< RealVector > const &  dataset)
protected

Attribute table functions. Creates the attribute tables used by the SPRINT algorithm.
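In SPRINT-style tree induction, each attribute gets its own table of (value, record id) pairs, sorted once by value so that candidate splits can be scanned in order. A minimal standalone sketch of that construction follows. The Entry struct with fields value and id, and the function makeAttributeTables, are assumptions made for illustration and may differ from TableEntry and createAttributeTables.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical table entry: one attribute value and the row it came from.
struct Entry {
    double value;
    std::size_t id;
};
using Table = std::vector<Entry>;
using Tables = std::vector<Table>;

// Builds one table per attribute, each sorted by attribute value.
Tables makeAttributeTables(std::vector<std::vector<double>> const& rows) {
    Tables tables;
    if (rows.empty()) return tables;
    std::size_t dims = rows.front().size();
    tables.resize(dims);
    for (std::size_t j = 0; j < dims; ++j) {
        for (std::size_t i = 0; i < rows.size(); ++i) {
            assert(rows[i].size() == dims);
            tables[j].push_back(Entry{rows[i][j], i});
        }
        // Stable sort keeps the original row order among equal values.
        std::stable_sort(tables[j].begin(), tables[j].end(),
                         [](Entry const& a, Entry const& b) {
                             return a.value < b.value;
                         });
    }
    return tables;
}
```

Sorting once up front is the design point: later splits only partition the tables and never need to sort again.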

§ createCountMatrix()

SHARK_EXPORT_SYMBOL boost::unordered_map<std::size_t, std::size_t> shark::CARTTrainer::createCountMatrix ( ClassificationDataset const &  dataset)
protected

Creates count matrices from a classification dataset.
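A count matrix in this sense maps each class label to the number of samples carrying it. The sketch below shows the idea with std::unordered_map standing in for boost::unordered_map and a plain label vector standing in for ClassificationDataset; the function name countLabels is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Counts how often each class label occurs.
std::unordered_map<std::size_t, std::size_t>
countLabels(std::vector<unsigned int> const& labels) {
    std::unordered_map<std::size_t, std::size_t> counts;
    for (unsigned int label : labels) {
        ++counts[label];  // operator[] value-initialises missing keys to 0
    }
    return counts;
}
```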

§ findNode()

SHARK_EXPORT_SYMBOL std::size_t shark::CARTTrainer::findNode ( SplitMatrixType &  splitMatrix,
std::size_t  nodeId 
)
protected

Returns the index of the node with node id in splitMatrix.

§ gini()

SHARK_EXPORT_SYMBOL double shark::CARTTrainer::gini ( boost::unordered_map< std::size_t, std::size_t > &  countMatrix,
std::size_t  n 
)
protected

Calculates the Gini impurity of a node. The impurity is defined as 1 - sum_j p(j|t)^2, i.e. 1 minus the sum over all classes j of the squared probability of observing class j in node t.
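The formula can be evaluated directly from a count matrix and the node size n, since p(j|t) is the count of class j divided by n. The following is a minimal standalone sketch under that reading; giniImpurity is a hypothetical name and std::unordered_map replaces boost::unordered_map so that the example needs no external library.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>

// Gini impurity 1 - sum_j p(j|t)^2, with p(j|t) = count_j / n.
double giniImpurity(
    std::unordered_map<std::size_t, std::size_t> const& countMatrix,
    std::size_t n) {
    assert(n > 0);
    double sumSq = 0.0;
    for (auto const& entry : countMatrix) {
        double p = static_cast<double>(entry.second) / static_cast<double>(n);
        sumSq += p * p;
    }
    return 1.0 - sumSq;
}
```

A pure node has impurity 0, and a node split evenly between two classes has impurity 0.5.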

§ hist()

SHARK_EXPORT_SYMBOL RealVector shark::CARTTrainer::hist ( boost::unordered_map< std::size_t, std::size_t >  countMatrix)
protected

Creates a histogram from the count matrix.
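One plausible reading of this step is sketched below: the counts are written into a dense vector with one slot per label from 0 to the maximum label, then divided by the total so that the entries are class frequencies. The normalisation is an assumption, as are the name histogram and the use of std::vector<double> in place of RealVector.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Dense class-frequency histogram with maxLabel + 1 entries.
std::vector<double> histogram(
    std::unordered_map<std::size_t, std::size_t> const& countMatrix,
    std::size_t maxLabel) {
    std::vector<double> h(maxLabel + 1, 0.0);
    std::size_t total = 0;
    for (auto const& entry : countMatrix) {
        assert(entry.first <= maxLabel);
        h[entry.first] = static_cast<double>(entry.second);
        total += entry.second;
    }
    // Assumed normalisation: turn counts into relative frequencies.
    if (total > 0) {
        for (double& v : h) v /= static_cast<double>(total);
    }
    return h;
}
```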

§ mean()

SHARK_EXPORT_SYMBOL RealVector shark::CARTTrainer::mean ( std::vector< RealVector > const &  labels)
protected

Calculates the mean of a vector of labels.
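For regression labels, the mean is taken component-wise over all label vectors. A minimal sketch with std::vector<double> standing in for RealVector; the name meanLabel is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Component-wise mean of a non-empty set of equally sized label vectors.
std::vector<double> meanLabel(std::vector<std::vector<double>> const& labels) {
    assert(!labels.empty());
    std::vector<double> avg(labels.front().size(), 0.0);
    for (auto const& label : labels) {
        assert(label.size() == avg.size());
        for (std::size_t d = 0; d < avg.size(); ++d) {
            avg[d] += label[d];
        }
    }
    for (double& v : avg) {
        v /= static_cast<double>(labels.size());
    }
    return avg;
}
```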

§ measureStrenght()

SHARK_EXPORT_SYMBOL void shark::CARTTrainer::measureStrenght ( SplitMatrixType &  splitMatrix,
std::size_t  nodeId,
std::size_t  parentNodeId 
)
protected

Updates the node variables used in the cost complexity pruning stage.

§ name()

std::string shark::CARTTrainer::name ( ) const
inline virtual

From INameable: return the class name.

Reimplemented from shark::INameable.

Definition at line 84 of file CARTTrainer.h.

References SHARK_EXPORT_SYMBOL, and train().

§ pruneMatrix()

SHARK_EXPORT_SYMBOL void shark::CARTTrainer::pruneMatrix ( SplitMatrixType &  splitMatrix)
protected

Pruning. Prunes the decision tree represented by a split matrix.

§ pruneNode()

SHARK_EXPORT_SYMBOL void shark::CARTTrainer::pruneNode ( SplitMatrixType &  splitMatrix,
std::size_t  nodeId 
)
protected

Prunes a single node, including the child nodes of the decision tree.

§ setNumberOfFolds()

void shark::CARTTrainer::setNumberOfFolds ( unsigned int  folds)
inline

Sets the number of folds used for creation of the trees.

Definition at line 94 of file CARTTrainer.h.

References m_numberOfFolds.

§ splitAttributeTables()

SHARK_EXPORT_SYMBOL void shark::CARTTrainer::splitAttributeTables ( AttributeTables const &  tables,
std::size_t  index,
std::size_t  valIndex,
AttributeTables &  LAttributeTables,
AttributeTables &  RAttributeTables 
)
protected

Splits the attribute tables by an attribute index and value. Returns a left and a right attribute table in the variables LAttributeTables and RAttributeTables.
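A sketch of how such a split can work in a SPRINT-style layout: the entries at positions 0 to valIndex of the chosen attribute table define which record ids go left, and every table is then partitioned by id while keeping its sorted order. Treating valIndex as inclusive is an assumption, and Entry, Table, Tables and splitTables are hypothetical stand-ins rather than the class's own types.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical table entry: one attribute value and the row it came from.
struct Entry {
    double value;
    std::size_t id;
};
using Table = std::vector<Entry>;
using Tables = std::vector<Table>;

// Partitions all tables according to a split on attribute `index` after
// position `valIndex` (inclusive) of that attribute's sorted table.
void splitTables(Tables const& tables, std::size_t index, std::size_t valIndex,
                 Tables& left, Tables& right) {
    assert(index < tables.size());
    assert(valIndex < tables[index].size());
    // Record ids that fall on the left side of the split.
    std::unordered_set<std::size_t> leftIds;
    for (std::size_t i = 0; i <= valIndex; ++i) {
        leftIds.insert(tables[index][i].id);
    }
    left.assign(tables.size(), Table());
    right.assign(tables.size(), Table());
    for (std::size_t j = 0; j < tables.size(); ++j) {
        for (Entry const& e : tables[j]) {
            // Appending in scan order preserves the per-attribute sorting.
            if (leftIds.count(e.id) != 0) {
                left[j].push_back(e);
            } else {
                right[j].push_back(e);
            }
        }
    }
}
```

The id lookup is what keeps all attribute tables consistent: a record ends up on the same side in every table, regardless of where it sits in each sorted order.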

§ totalSumOfSquares()

SHARK_EXPORT_SYMBOL double shark::CARTTrainer::totalSumOfSquares ( std::vector< RealVector > const &  labels,
std::size_t  start,
std::size_t  length,
const RealVector &  sumLabel 
)
protected

Calculates the total sum of squares.
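The total sum of squares of a range of labels is the summed squared distance of each label to the mean of that range. The sketch below follows the parameter list of the signature (a start offset, a length and the precomputed sum of the labels in the range), but it is an independent illustration with std::vector<double> in place of RealVector, not the library's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sum over labels[start .. start+length) of ||label - mean||^2, where the
// mean is sumLabel / length and sumLabel is the sum of those labels.
double totalSumOfSquares(std::vector<std::vector<double>> const& labels,
                         std::size_t start, std::size_t length,
                         std::vector<double> const& sumLabel) {
    assert(length >= 1);
    assert(start + length <= labels.size());
    std::vector<double> avg(sumLabel);
    for (double& v : avg) v /= static_cast<double>(length);
    double total = 0.0;
    for (std::size_t i = start; i < start + length; ++i) {
        assert(labels[i].size() == avg.size());
        for (std::size_t d = 0; d < avg.size(); ++d) {
            double diff = labels[i][d] - avg[d];
            total += diff * diff;
        }
    }
    return total;
}
```

For the one-dimensional labels 1, 2, 3 with sum 6, the mean is 2 and the total sum of squares is 1 + 0 + 1 = 2.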

§ train() [1/2]

SHARK_EXPORT_SYMBOL void shark::CARTTrainer::train ( ModelType &  model,
ClassificationDataset const &  dataset 
)

Train classification.

Referenced by main(), and name().

§ train() [2/2]

SHARK_EXPORT_SYMBOL void shark::CARTTrainer::train ( ModelType &  model,
RegressionDataset const &  dataset 
)

Train regression.

Member Data Documentation

§ m_inputDimension

std::size_t shark::CARTTrainer::m_inputDimension
protected

Number of attributes in the dataset.

Definition at line 115 of file CARTTrainer.h.

§ m_labelDimension

std::size_t shark::CARTTrainer::m_labelDimension
protected

Size of labels.

Definition at line 118 of file CARTTrainer.h.

§ m_maxLabel

unsigned int shark::CARTTrainer::m_maxLabel
protected

Holds the maximum label. Used in allocating the histograms.

Definition at line 124 of file CARTTrainer.h.

§ m_nodeSize

std::size_t shark::CARTTrainer::m_nodeSize
protected

Controls the number of samples in the terminal nodes.

Definition at line 121 of file CARTTrainer.h.

Referenced by CARTTrainer().

§ m_numberOfFolds

unsigned int shark::CARTTrainer::m_numberOfFolds
protected

Number of folds used to create the tree.

Definition at line 127 of file CARTTrainer.h.

Referenced by CARTTrainer(), and setNumberOfFolds().


The documentation for this class was generated from the following file: