CARTTrainer.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief CART
6  *
7  *
8  *
9  * \author K. N. Hansen
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 
36 #ifndef SHARK_ALGORITHMS_TRAINERS_CARTTRAINER_H
37 #define SHARK_ALGORITHMS_TRAINERS_CARTTRAINER_H
38 
39 #include <shark/Core/DLLSupport.h>
42 #include <boost/unordered_map.hpp>
43 
44 namespace shark {
45 /*!
 * \brief Classification And Regression Trees (CART)
47  *
 * CART is a decision tree algorithm that builds a binary decision tree.
 * The tree is built by recursively partitioning the dataset.
50  *
 * The partitioning is done so that the split chosen at each node is the
 * one that produces the largest decrease in node impurity.
54  *
 * Node impurity is measured by the Gini index in the classification case
 * and by the total sum of squared errors in the regression case.
57  *
 * The tree is grown until all leaves are pure. A node is considered pure
 * when it consists only of identical cases in the classification case,
 * and of identical values or a single value in the regression case.
61  *
 * After the maximum-sized tree has been grown, it is pruned back from the
 * leaves upward using cost-complexity pruning, as described by L. Breiman.
64  *
 * The implementation is based on the SPRINT algorithm by J. Shafer et al.
66  *
67  * For more detailed information about CART, see \e Classification \e And \e Regression
68  * \e Trees written by L. Breiman et al. 1984.
69  */
71 : public AbstractTrainer<CARTClassifier<RealVector>, unsigned int>
72 , public AbstractTrainer<CARTClassifier<RealVector>, RealVector >
73 {
74 public:
76 
77  /// Constructor: sets the default terminal node size (m_nodeSize = 1)
/// and the default number of folds (m_numberOfFolds = 10); the latter can
/// be changed later with setNumberOfFolds().
/// NOTE(review): the body shown does not assign m_inputDimension,
/// m_labelDimension or m_maxLabel; presumably train() sets them before
/// they are read — confirm.
79  m_nodeSize = 1;
80  m_numberOfFolds = 10;
81  }
82 
83  /// \brief From INameable: return the class name.
84  std::string name() const
85  { return "CARTTrainer"; }
86 
87  ///Train classification
88  SHARK_EXPORT_SYMBOL void train(ModelType& model, ClassificationDataset const& dataset);
89 
90  ///Train regression
91  SHARK_EXPORT_SYMBOL void train(ModelType& model, RegressionDataset const& dataset);
92 
93  ///Sets the number of folds used for creation of the trees.
94  void setNumberOfFolds(unsigned int folds){
95  m_numberOfFolds = folds;
96  }
97 protected:
98 
/// \brief One entry of an attribute table: an attribute value paired with
/// the id of the sample it belongs to.
///
/// Entries are ordered by value alone, so sorting a table orders its
/// samples by that attribute's value.
struct TableEntry{
	double value;    ///< attribute value
	std::size_t id;  ///< sample id (presumably the sample's index in the dataset — confirm)

	/// Compares two entries by value only; the id does not take part.
	bool operator<(TableEntry const& rhs) const {
		return this->value < rhs.value;
	}
};

/// A list of table entries for a single attribute.
typedef std::vector<TableEntry> AttributeTable;
/// A collection of attribute tables (presumably one per input attribute).
typedef std::vector<AttributeTable> AttributeTables;
110 
112 
113 
std::size_t m_inputDimension;   ///< Number of attributes (input dimensions) in the dataset.
std::size_t m_labelDimension;   ///< Dimension of the labels.
std::size_t m_nodeSize;         ///< Controls the number of samples in the terminal nodes.
unsigned int m_maxLabel;        ///< Largest label; used when allocating the histograms.
unsigned int m_numberOfFolds;   ///< Number of folds used to create the tree.
128 
129  //Classification functions
130  ///Builds a single decision tree from a classification dataset
131  ///The method requires the attribute tables,
132  SHARK_EXPORT_SYMBOL SplitMatrixType buildTree(AttributeTables const& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId );
133 
134  ///Calculates the Gini impurity of a node. The impurity is defined as
135  ///1-sum_j p(j|t)^2
136  ///i.e the 1 minus the sum of the squared probability of observing class j in node t
137  SHARK_EXPORT_SYMBOL double gini(boost::unordered_map<std::size_t, std::size_t>& countMatrix, std::size_t n);
138  ///Creates a histogram from the count matrix.
139  SHARK_EXPORT_SYMBOL RealVector hist(boost::unordered_map<std::size_t, std::size_t> countMatrix);
140 
141  ///Regression functions
142  SHARK_EXPORT_SYMBOL SplitMatrixType buildTree(AttributeTables const& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId, std::size_t trainSize);
143  ///Calculates the total sum of squares
144  SHARK_EXPORT_SYMBOL double totalSumOfSquares(std::vector<RealVector> const& labels, std::size_t start, std::size_t length, const RealVector& sumLabel);
145  ///Calculates the mean of a vector of labels
146  SHARK_EXPORT_SYMBOL RealVector mean(std::vector<RealVector> const& labels);
147 
148  ///Pruning
149  ///Prunes decision tree, represented by a split matrix
150  SHARK_EXPORT_SYMBOL void pruneMatrix(SplitMatrixType& splitMatrix);
151  ///Prunes a single node, including the child nodes of the decision tree
152  SHARK_EXPORT_SYMBOL void pruneNode(SplitMatrixType& splitMatrix, std::size_t nodeId);
153  ///Updates the node variables used in the cost complexity pruning stage
154  SHARK_EXPORT_SYMBOL void measureStrenght(SplitMatrixType& splitMatrix, std::size_t nodeId, std::size_t parentNodeId);
155 
156  ///Returns the index of the node with node id in splitMatrix.
157  SHARK_EXPORT_SYMBOL std::size_t findNode(SplitMatrixType& splitMatrix, std::size_t nodeId);
158 
159  ///Attribute table functions
160  ///Create the attribute tables used by the SPRINT algorithm
161  SHARK_EXPORT_SYMBOL AttributeTables createAttributeTables(Data<RealVector> const& dataset);
162  ///Splits the attribute tables by a attribute index and value. Returns a left and a right attribute table in the variables LAttributeTables and RAttributeTables
163  SHARK_EXPORT_SYMBOL void splitAttributeTables(AttributeTables const& tables, std::size_t index, std::size_t valIndex, AttributeTables& LAttributeTables, AttributeTables& RAttributeTables);
164  ///Crates count matrices from a classification dataset
165  SHARK_EXPORT_SYMBOL boost::unordered_map<std::size_t, std::size_t> createCountMatrix(ClassificationDataset const& dataset);
166 
167 
168 };
169 
170 
171 }
172 #endif
173