Classification And Regression Trees CART.
#include <shark/Algorithms/Trainers/CARTTrainer.h>
Classes | |
struct | TableEntry |
Types frequently used. | |
Public Types | |
typedef CARTClassifier< RealVector > | ModelType |
typedef CARTClassifier< RealVector > | ModelType |
typedef ModelType::InputType | InputType |
typedef unsigned int | LabelType |
typedef LabeledData< InputType, LabelType > | DatasetType |
typedef CARTClassifier< RealVector > | ModelType |
typedef ModelType::InputType | InputType |
typedef RealVector | LabelType |
typedef LabeledData< InputType, LabelType > | DatasetType |
Public Member Functions | |
CARTTrainer () | |
Constructor. | |
std::string | name () const |
From INameable: return the class name. | |
SHARK_EXPORT_SYMBOL void | train (ModelType &model, ClassificationDataset const &dataset) |
Train classification. | |
SHARK_EXPORT_SYMBOL void | train (ModelType &model, RegressionDataset const &dataset) |
Train regression. | |
void | setNumberOfFolds (unsigned int folds) |
Sets the number of folds used for creation of the trees. | |
virtual void | train (ModelType &model, DatasetType const &dataset)=0 |
Core of the Trainer interface. | |
virtual | ~INameable () |
virtual | ~ISerializable () |
Virtual d'tor. | |
virtual void | read (InArchive &archive) |
Read the component from the supplied archive. | |
virtual void | write (OutArchive &archive) const |
Write the component to the supplied archive. | |
void | load (InArchive &archive, unsigned int version) |
Versioned loading of components, calls read(...). | |
void | save (OutArchive &archive, unsigned int version) const |
Versioned storing of components, calls write(...). | |
BOOST_SERIALIZATION_SPLIT_MEMBER () | |
virtual void | train (ModelType &model, DatasetType const &dataset)=0 |
Core of the Trainer interface. | |
Protected Types | |
typedef std::vector< TableEntry > | AttributeTable |
typedef std::vector< AttributeTable > | AttributeTables |
typedef ModelType::SplitMatrixType | SplitMatrixType |
Protected Member Functions | |
SHARK_EXPORT_SYMBOL SplitMatrixType | buildTree (AttributeTables const &tables, ClassificationDataset const &dataset, boost::unordered_map< std::size_t, std::size_t > &cAbove, std::size_t nodeId) |
SHARK_EXPORT_SYMBOL double | gini (boost::unordered_map< std::size_t, std::size_t > &countMatrix, std::size_t n) |
SHARK_EXPORT_SYMBOL RealVector | hist (boost::unordered_map< std::size_t, std::size_t > countMatrix) |
Creates a histogram from the count matrix. | |
SHARK_EXPORT_SYMBOL SplitMatrixType | buildTree (AttributeTables const &tables, RegressionDataset const &dataset, std::vector< RealVector > const &labels, std::size_t nodeId, std::size_t trainSize) |
Regression functions. | |
SHARK_EXPORT_SYMBOL double | totalSumOfSquares (std::vector< RealVector > const &labels, std::size_t start, std::size_t length, const RealVector &sumLabel) |
Calculates the total sum of squares. | |
SHARK_EXPORT_SYMBOL RealVector | mean (std::vector< RealVector > const &labels) |
Calculates the mean of a vector of labels. | |
SHARK_EXPORT_SYMBOL void | pruneMatrix (SplitMatrixType &splitMatrix) |
SHARK_EXPORT_SYMBOL void | pruneNode (SplitMatrixType &splitMatrix, std::size_t nodeId) |
Prunes a single node, including the child nodes of the decision tree. | |
SHARK_EXPORT_SYMBOL void | measureStrenght (SplitMatrixType &splitMatrix, std::size_t nodeId, std::size_t parentNodeId) |
Updates the node variables used in the cost complexity pruning stage. | |
SHARK_EXPORT_SYMBOL std::size_t | findNode (SplitMatrixType &splitMatrix, std::size_t nodeId) |
Returns the index of the node with node id in splitMatrix. | |
SHARK_EXPORT_SYMBOL AttributeTables | createAttributeTables (Data< RealVector > const &dataset) |
SHARK_EXPORT_SYMBOL void | splitAttributeTables (AttributeTables const &tables, std::size_t index, std::size_t valIndex, AttributeTables &LAttributeTables, AttributeTables &RAttributeTables) |
Splits the attribute tables by an attribute index and value. Returns a left and a right attribute table in the variables LAttributeTables and RAttributeTables. | |
SHARK_EXPORT_SYMBOL boost::unordered_map< std::size_t, std::size_t > | createCountMatrix (ClassificationDataset const &dataset) |
Creates count matrices from a classification dataset. | |
Protected Attributes | |
std::size_t | m_inputDimension |
Number of attributes in the dataset. | |
std::size_t | m_labelDimension |
Size of labels. | |
std::size_t | m_nodeSize |
Controls the number of samples in the terminal nodes. | |
unsigned int | m_maxLabel |
Holds the maximum label. Used in allocating the histograms. | |
unsigned int | m_numberOfFolds |
Number of folds used to create the tree. | |
Classification And Regression Trees CART.
CART is a decision tree algorithm that builds a binary decision tree. The tree is built by recursively partitioning a dataset.
At each node, the partition chosen is the one that produces the largest decrease in node impurity.
Node impurity is measured by the Gini criterion in the classification case and by the total sum of squares error in the regression case.
The tree is grown until all leaves are pure. A node is considered pure when it consists only of identical cases in the classification case, and of identical or single values in the regression case.
After the maximum-sized tree is grown, it is pruned back from the leaves upward. Pruning is done by cost-complexity pruning, as described by L. Breiman.
The algorithm used is based on the SPRINT algorithm, as shown by J. Shafer et al.
For more detailed information about CART, see Classification And Regression Trees by L. Breiman et al. (1984).
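The split selection described above can be sketched for a single numeric attribute as follows. This is a minimal standalone illustration, not the Shark implementation: `Sample`, `Split`, `giniOf`, and `bestSplit` are hypothetical names, and the sketch assumes the SPRINT-style approach of sorting the attribute once and moving class counts from an "above" table to a "below" table while scanning.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// One (attribute value, class label) pair; a hypothetical stand-in for an
// attribute-table entry.
struct Sample {
    double value;
    std::size_t label;
};

// Result of the scan: number of samples sent to the left child and the
// impurity decrease achieved. leftSize == 0 means no useful split was found.
struct Split {
    std::size_t leftSize;
    double impurityDecrease;
};

using ClassCounts = std::map<std::size_t, std::size_t>;

// Gini impurity 1 - sum_j p(j|t)^2 of a node with n samples.
double giniOf(ClassCounts const& counts, std::size_t n) {
    double sumOfSquares = 0.0;
    for (auto const& entry : counts) {
        double p = static_cast<double>(entry.second) / static_cast<double>(n);
        sumOfSquares += p * p;
    }
    return 1.0 - sumOfSquares;
}

// Sorts the samples by attribute value and tries every boundary between two
// distinct values, keeping the one with the largest decrease in impurity.
Split bestSplit(std::vector<Sample> samples) {
    assert(samples.size() >= 2 && "need at least two samples to split");
    std::sort(samples.begin(), samples.end(),
              [](Sample const& a, Sample const& b) { return a.value < b.value; });

    ClassCounts below, above;
    for (Sample const& s : samples) ++above[s.label];

    std::size_t const n = samples.size();
    double const parentImpurity = giniOf(above, n);
    Split best{0, 0.0};

    for (std::size_t i = 0; i + 1 < n; ++i) {
        // Move sample i from the right partition to the left partition.
        ++below[samples[i].label];
        --above[samples[i].label];
        // Equal neighbouring values cannot be separated by a threshold.
        if (samples[i].value == samples[i + 1].value) continue;

        std::size_t const nLeft = i + 1;
        std::size_t const nRight = n - nLeft;
        double const weighted =
            (static_cast<double>(nLeft) * giniOf(below, nLeft) +
             static_cast<double>(nRight) * giniOf(above, nRight)) /
            static_cast<double>(n);
        double const decrease = parentImpurity - weighted;
        if (decrease > best.impurityDecrease) best = Split{nLeft, decrease};
    }
    return best;
}
```

A full tree builder would repeat this scan for every attribute, split on the best one, and recurse on both partitions until the nodes are pure.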
Definition at line 70 of file CARTTrainer.h.
|
protected |
Definition at line 108 of file CARTTrainer.h.
|
protected |
Definition at line 109 of file CARTTrainer.h.
typedef CARTClassifier<RealVector> shark::CARTTrainer::ModelType |
Definition at line 75 of file CARTTrainer.h.
|
protected |
Definition at line 111 of file CARTTrainer.h.
|
inline |
Constructor.
Definition at line 78 of file CARTTrainer.h.
References m_nodeSize, and m_numberOfFolds.
|
protected |
Builds a single decision tree from a classification dataset. The method requires the attribute tables.
|
protected |
Regression functions.
|
protected |
Attribute table functions. Creates the attribute tables used by the SPRINT algorithm.
|
protected |
Creates count matrices from a classification dataset.
|
protected |
Returns the index of the node with node id in splitMatrix.
|
protected |
Calculates the Gini impurity of a node. The impurity is defined as 1 - sum_j p(j|t)^2, i.e. one minus the sum over all classes j of the squared probability of observing class j in node t.
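As a concrete illustration of the formula, the following minimal sketch computes the Gini impurity from a map of class counts. It uses `std::unordered_map` in place of `boost::unordered_map` so that it compiles without Boost, and `giniImpurity` is a hypothetical name rather than the Shark member function.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>

// Gini impurity 1 - sum_j p(j|t)^2 for a node t holding n samples.
// counts maps a class label j to the number of samples of class j in t.
double giniImpurity(std::unordered_map<std::size_t, std::size_t> const& counts,
                    std::size_t n) {
    assert(n > 0 && "a node must contain at least one sample");
    double sumOfSquares = 0.0;
    for (auto const& entry : counts) {
        // p(j|t): fraction of the node's samples that belong to class j.
        double p = static_cast<double>(entry.second) / static_cast<double>(n);
        sumOfSquares += p * p;
    }
    return 1.0 - sumOfSquares;
}
```

A node with a single class has impurity 0, and a node split evenly between two classes has impurity 0.5, the maximum for the two-class case.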
|
protected |
Creates a histogram from the count matrix.
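One plausible reading of this step is sketched below: the class counts are spread into a dense vector indexed by label and normalised to relative frequencies. Whether the library normalises the histogram is an assumption here, and `classHistogram` and its `maxLabel` parameter are hypothetical names chosen to mirror `m_maxLabel`.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Builds a dense histogram with one bin per label in [0, maxLabel].
// Assumption: bins hold relative frequencies that sum to 1.
std::vector<double> classHistogram(
    std::unordered_map<std::size_t, std::size_t> const& counts,
    std::size_t maxLabel) {
    std::vector<double> histogram(maxLabel + 1, 0.0);
    std::size_t total = 0;
    for (auto const& entry : counts) {
        assert(entry.first <= maxLabel && "label exceeds maxLabel");
        histogram[entry.first] = static_cast<double>(entry.second);
        total += entry.second;
    }
    // Leave an all-zero histogram untouched to avoid dividing by zero.
    if (total > 0) {
        for (double& bin : histogram) bin /= static_cast<double>(total);
    }
    return histogram;
}
```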
|
protected |
Calculates the mean of a vector of labels.
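For regression, each label is itself a vector, so the mean is taken component-wise. The sketch below illustrates this with `std::vector<double>` standing in for `RealVector`; `meanLabel` is a hypothetical name, not the library function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Component-wise mean of a non-empty set of equally sized label vectors.
std::vector<double> meanLabel(std::vector<std::vector<double>> const& labels) {
    assert(!labels.empty() && "no labels provided");
    std::vector<double> average(labels.front().size(), 0.0);
    for (auto const& label : labels) {
        assert(label.size() == average.size() && "label sizes must match");
        for (std::size_t d = 0; d < average.size(); ++d) average[d] += label[d];
    }
    for (double& component : average) {
        component /= static_cast<double>(labels.size());
    }
    return average;
}
```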
|
protected |
Updates the node variables used in the cost complexity pruning stage.
|
inline virtual |
From INameable: return the class name.
Reimplemented from shark::INameable.
Definition at line 84 of file CARTTrainer.h.
References SHARK_EXPORT_SYMBOL, and train().
|
protected |
Pruning. Prunes a decision tree represented by a split matrix.
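Cost-complexity pruning, as described by Breiman et al., repeatedly removes the "weakest link": the internal node t that minimises g(t) = (R(t) - R(T_t)) / (|leaves(T_t)| - 1), where R(t) is the error of t treated as a leaf and R(T_t) is the summed error of the leaves below it. The following minimal sketch computes g(t) on a toy tree. The `Node` layout and all function names are hypothetical and do not reflect the split matrix used by the library.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Toy tree node: child indices into the vector, -1 marks a leaf.
struct Node {
    int left;
    int right;
    double risk;  // R(t): error if this node were collapsed into a leaf
};

struct SubtreeStats {
    double leafRisk;     // R(T_t): summed risk of all leaves below t
    std::size_t leaves;  // number of leaves below t
};

SubtreeStats subtreeStats(std::vector<Node> const& tree, int id) {
    Node const& node = tree[static_cast<std::size_t>(id)];
    if (node.left < 0) return SubtreeStats{node.risk, 1};
    SubtreeStats l = subtreeStats(tree, node.left);
    SubtreeStats r = subtreeStats(tree, node.right);
    return SubtreeStats{l.leafRisk + r.leafRisk, l.leaves + r.leaves};
}

// g(t): increase in error per leaf removed when the subtree at t is pruned.
double linkStrength(std::vector<Node> const& tree, int id) {
    Node const& node = tree[static_cast<std::size_t>(id)];
    assert(node.left >= 0 && "link strength is defined for internal nodes");
    SubtreeStats s = subtreeStats(tree, id);
    return (node.risk - s.leafRisk) / static_cast<double>(s.leaves - 1);
}

// Index of the internal node with the smallest g(t), or -1 if none exists.
int weakestLink(std::vector<Node> const& tree) {
    int best = -1;
    double bestStrength = std::numeric_limits<double>::infinity();
    for (std::size_t i = 0; i < tree.size(); ++i) {
        if (tree[i].left < 0) continue;  // skip leaves
        double g = linkStrength(tree, static_cast<int>(i));
        if (g < bestStrength) {
            bestStrength = g;
            best = static_cast<int>(i);
        }
    }
    return best;
}
```

Pruning the weakest link, recomputing g(t), and repeating yields a nested sequence of subtrees, from which a final tree is typically selected by cross-validation.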
|
protected |
Prunes a single node, including the child nodes of the decision tree.
|
inline |
Sets the number of folds used for creation of the trees.
Definition at line 94 of file CARTTrainer.h.
References m_numberOfFolds.
|
protected |
Splits the attribute tables by an attribute index and value. Returns a left and a right attribute table in the variables LAttributeTables and RAttributeTables.
|
protected |
Calculates the total sum of squares.
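The total sum of squares of a node is the sum of squared distances between each label and the node's mean label. The sketch below follows the parameter list shown in the member summary (a range given by `start` and `length`, plus a precomputed label sum), but the body is an assumption about how those parameters are used, and `sumOfSquaresAroundMean` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sum over labels[start .. start+length) of ||y_i - mean||^2, where
// mean = sumLabel / length and sumLabel is the component-wise sum of the
// labels in that range (assumed to be precomputed by the caller).
double sumOfSquaresAroundMean(std::vector<std::vector<double>> const& labels,
                              std::size_t start, std::size_t length,
                              std::vector<double> const& sumLabel) {
    assert(length > 0 && "no labels provided");
    assert(start + length <= labels.size() && "range exceeds label count");
    double total = 0.0;
    for (std::size_t i = start; i < start + length; ++i) {
        assert(labels[i].size() == sumLabel.size() && "label sizes must match");
        for (std::size_t d = 0; d < sumLabel.size(); ++d) {
            double diff = labels[i][d] - sumLabel[d] / static_cast<double>(length);
            total += diff * diff;
        }
    }
    return total;
}
```

Passing the running label sum instead of recomputing the mean lets a split scan update the left and right sums incrementally as samples move between partitions.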
SHARK_EXPORT_SYMBOL void shark::CARTTrainer::train (ModelType & model, ClassificationDataset const & dataset)
Train classification.
SHARK_EXPORT_SYMBOL void shark::CARTTrainer::train (ModelType & model, RegressionDataset const & dataset)
Train regression.
|
protected |
Number of attributes in the dataset.
Definition at line 115 of file CARTTrainer.h.
|
protected |
Size of labels.
Definition at line 118 of file CARTTrainer.h.
|
protected |
Holds the maximum label. Used in allocating the histograms.
Definition at line 124 of file CARTTrainer.h.
|
protected |
Controls the number of samples in the terminal nodes.
Definition at line 121 of file CARTTrainer.h.
Referenced by CARTTrainer().
|
protected |
Number of folds used to create the tree.
Definition at line 127 of file CARTTrainer.h.
Referenced by CARTTrainer(), and setNumberOfFolds().