include/shark/Algorithms/Trainers/CARTTrainer.h
//===========================================================================
/*!
 *
 * \brief CART
 *
 * \author K. N. Hansen
 * \date 2012
 *
 * \par Copyright 1995-2015 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://image.diku.dk/shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#ifndef SHARK_ALGORITHMS_TRAINERS_CARTTRAINER_H
#define SHARK_ALGORITHMS_TRAINERS_CARTTRAINER_H

#include <shark/Core/DLLSupport.h>
#include <shark/Models/Trees/CARTClassifier.h>
#include <shark/Algorithms/Trainers/AbstractTrainer.h>
#include <boost/unordered_map.hpp>

namespace shark {
/*!
 * \brief Classification And Regression Trees (CART)
 *
 * CART is a decision tree algorithm that builds a binary decision tree
 * by recursively partitioning a dataset.
 *
 * At each node, the partition chosen is the one that produces the largest
 * decrease in node impurity.
 *
 * The node impurity is measured by the Gini criterion in the classification
 * case and by the total sum of squares error in the regression case.
 *
 * The tree is grown until all leaves are pure. A node is considered pure
 * when it consists only of identical cases in the classification case,
 * or of identical or single values in the regression case.
 *
 * After the maximum-sized tree is grown, it is pruned back from the leaves
 * upward. The pruning is done by cost complexity pruning, as described by
 * L. Breiman.
 *
 * The algorithm used is based on the SPRINT algorithm by J. Shafer et al.
 *
 * For more detailed information about CART, see \e Classification \e And
 * \e Regression \e Trees by L. Breiman et al., 1984.
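 *
 * \par Example
 * A minimal usage sketch (illustrative only, not taken from the original
 * header); filling the dataset with labeled samples is assumed to happen
 * elsewhere:
 * \code
 * #include <shark/Algorithms/Trainers/CARTTrainer.h>
 * using namespace shark;
 *
 * ClassificationDataset data;   // assumed to hold labeled training samples
 * CARTClassifier<RealVector> tree;
 * CARTTrainer trainer;
 * trainer.setNumberOfFolds(5);  // optional; the constructor default is 10
 * trainer.train(tree, data);    // grows the tree, then applies cost complexity pruning
 * \endcode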
 */
class CARTTrainer
: public AbstractTrainer<CARTClassifier<RealVector>, unsigned int>
, public AbstractTrainer<CARTClassifier<RealVector>, RealVector >
{
public:
    typedef CARTClassifier<RealVector> ModelType;

    /// Constructor
    CARTTrainer(){
        m_nodeSize = 1;
        m_numberOfFolds = 10;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "CARTTrainer"; }

    ///Train classification
    SHARK_EXPORT_SYMBOL void train(ModelType& model, ClassificationDataset const& dataset);

    ///Train regression
    SHARK_EXPORT_SYMBOL void train(ModelType& model, RegressionDataset const& dataset);

    ///Sets the number of folds used for creation of the trees.
    void setNumberOfFolds(unsigned int folds){
        m_numberOfFolds = folds;
    }
protected:

    ///Types frequently used
    struct TableEntry{
        double value;
        std::size_t id;

        bool operator<(TableEntry const& v2) const{
            return value < v2.value;
        }
    };
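    ///An attribute table collects (value, id) pairs of the samples for one
    ///input attribute, following the attribute-list idea of the SPRINT
    ///algorithm (see createAttributeTables below).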
    typedef std::vector<TableEntry> AttributeTable;
    typedef std::vector<AttributeTable> AttributeTables;

    typedef ModelType::SplitMatrixType SplitMatrixType;


    ///Number of attributes in the dataset
    std::size_t m_inputDimension;

    ///Size of labels
    std::size_t m_labelDimension;

    ///Controls the number of samples in the terminal nodes
    std::size_t m_nodeSize;

    ///Holds the maximum label. Used in allocating the histograms
    unsigned int m_maxLabel;

    ///Number of folds used to create the tree.
    unsigned int m_numberOfFolds;

    //Classification functions
    ///Builds a single decision tree from a classification dataset.
    ///The method requires the attribute tables.
    SHARK_EXPORT_SYMBOL SplitMatrixType buildTree(AttributeTables const& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId);

    ///Calculates the Gini impurity of a node. The impurity is defined as
    ///1-sum_j p(j|t)^2
    ///i.e. 1 minus the sum of the squared probabilities of observing class j in node t.
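    ///For example, a node holding 3 samples of class A and 1 sample of class B
    ///(n = 4) has impurity 1 - (0.75^2 + 0.25^2) = 0.375.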
    SHARK_EXPORT_SYMBOL double gini(boost::unordered_map<std::size_t, std::size_t>& countMatrix, std::size_t n);
    ///Creates a histogram from the count matrix.
    SHARK_EXPORT_SYMBOL RealVector hist(boost::unordered_map<std::size_t, std::size_t> countMatrix);

    ///Regression functions
    SHARK_EXPORT_SYMBOL SplitMatrixType buildTree(AttributeTables const& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId, std::size_t trainSize);
    ///Calculates the total sum of squares
    SHARK_EXPORT_SYMBOL double totalSumOfSquares(std::vector<RealVector> const& labels, std::size_t start, std::size_t length, const RealVector& sumLabel);
    ///Calculates the mean of a vector of labels
    SHARK_EXPORT_SYMBOL RealVector mean(std::vector<RealVector> const& labels);

    ///Pruning
    ///Prunes the decision tree, represented by a split matrix
    SHARK_EXPORT_SYMBOL void pruneMatrix(SplitMatrixType& splitMatrix);
    ///Prunes a single node, including the child nodes of the decision tree
    SHARK_EXPORT_SYMBOL void pruneNode(SplitMatrixType& splitMatrix, std::size_t nodeId);
    ///Updates the node variables used in the cost complexity pruning stage
    SHARK_EXPORT_SYMBOL void measureStrenght(SplitMatrixType& splitMatrix, std::size_t nodeId, std::size_t parentNodeId);

    ///Returns the index of the node with the given node id in splitMatrix.
    SHARK_EXPORT_SYMBOL std::size_t findNode(SplitMatrixType& splitMatrix, std::size_t nodeId);

    ///Attribute table functions
    ///Create the attribute tables used by the SPRINT algorithm
    SHARK_EXPORT_SYMBOL AttributeTables createAttributeTables(Data<RealVector> const& dataset);
    ///Splits the attribute tables by an attribute index and value. Returns a left and a right attribute table in the variables LAttributeTables and RAttributeTables.
    SHARK_EXPORT_SYMBOL void splitAttributeTables(AttributeTables const& tables, std::size_t index, std::size_t valIndex, AttributeTables& LAttributeTables, AttributeTables& RAttributeTables);
    ///Creates a count matrix from a classification dataset
    SHARK_EXPORT_SYMBOL boost::unordered_map<std::size_t, std::size_t> createCountMatrix(ClassificationDataset const& dataset);

};


}
#endif