RFTrainer.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Random Forest Trainer
6  *
7  *
8  *
9  * \author K. N. Hansen, J. Kremer
10  * \date 2011-2012
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 
36 #ifndef SHARK_ALGORITHMS_TRAINERS_RFTRAINER_H
37 #define SHARK_ALGORITHMS_TRAINERS_RFTRAINER_H
38 
#include <shark/Core/DLLSupport.h>
#include <shark/Models/Trees/RFClassifier.h>
#include <shark/Algorithms/Trainers/AbstractTrainer.h>

#include <boost/unordered_map.hpp>
#include <set>
45 
46 namespace shark {
47 /*!
48  * \brief Random Forest
49  *
50  * Random Forest is an ensemble learner, that builds multiple binary decision trees.
51  * The trees are built using a variant of the CART methodology
52  *
53  * The algorithm used to generate each tree is based on the SPRINT algorithm, as
54  * shown by J. Shafer et al.
55  *
56  * Typically 100+ trees are built, and classification/regression is done by combining
57  * the results generated by each tree. Typically a majority vote is used in the
58  * classification case, and the mean is used in the regression case
59  *
60  * Each tree is built based on a random subset of the total dataset. Furthermore
61  * at each split, only a random subset of the attributes are investigated for
62  * the best split
63  *
64  * The node impurity is measured by the Gini criteria in the classification
65  * case, and the total sum of squared errors in the regression case
66  *
67  * After growing a maximum sized tree, the tree is added to the ensemble
68  * without pruning.
69  *
70  * For detailed information about Random Forest, see Random Forest
71  * by L. Breiman et al. 2001.
72  *
73  * For detailed information about the SPRINT algorithm, see
74  * SPRINT: A Scalable Parallel Classifier for Data Mining
75  * by J. Shafer et al.
76  */
77 class RFTrainer
78 : public AbstractTrainer<RFClassifier, unsigned int>
79 , public AbstractTrainer<RFClassifier>,
80  public IParameterizable
81 {
82 
83 public:
84  /// Construct and compute feature importances when training or not
85  SHARK_EXPORT_SYMBOL RFTrainer(bool computeFeatureImportances = false, bool computeOOBerror = false);
86 
87  /// \brief From INameable: return the class name.
88  std::string name() const
89  { return "RFTrainer"; }
90 
91  /// Train a random forest for classification.
92  SHARK_EXPORT_SYMBOL void train(RFClassifier& model, const ClassificationDataset& dataset);
93 
94  /// Train a random forest for regression.
95  SHARK_EXPORT_SYMBOL void train(RFClassifier& model, const RegressionDataset& dataset);
96 
97  /// Set the number of random attributes to investigate at each node.
98  SHARK_EXPORT_SYMBOL void setMTry(std::size_t mtry);
99 
100  /// Set the number of trees to grow.
101  SHARK_EXPORT_SYMBOL void setNTrees(std::size_t nTrees);
102 
103  /// Controls when a node is considered pure. If set to 1, a node is pure
104  /// when it only consists of a single node.
105  SHARK_EXPORT_SYMBOL void setNodeSize(std::size_t nTrees);
106 
107  /// Set the fraction of the original training dataset to use as the
108  /// out of bag sample. The default value is 0.66.
109  SHARK_EXPORT_SYMBOL void setOOBratio(double ratio);
110 
111  /// Return the parameter vector.
112  RealVector parameterVector() const
113  {
114  RealVector ret(1); // number of trees
115  init(ret) << (double)m_B;
116  return ret;
117  }
118 
119  /// Set the parameter vector.
120  void setParameterVector(RealVector const& newParameters)
121  {
122  SHARK_ASSERT(newParameters.size() == numberOfParameters());
123  setNTrees((size_t) newParameters[0]);
124  }
125 
126 protected:
127  struct RFAttribute {
128  double value;
129  std::size_t id;
130  };
131 
132  /// attribute table
133  typedef std::vector < RFAttribute > AttributeTable;
134  /// collecting of attribute tables
135  typedef std::vector < AttributeTable > AttributeTables;
136 
137  /// Create attribute tables from a data set, and in the process create a count matrix (cAbove).
138  /// A dataset with m features results in m attribute tables.
139  /// [attribute | class/value | row id ]
140  SHARK_EXPORT_SYMBOL void createAttributeTables(Data<RealVector> const& dataset, AttributeTables& tables);
141 
142  /// Create a count matrix as used in the classification case.
143  SHARK_EXPORT_SYMBOL void createCountMatrix(const ClassificationDataset& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove);
144 
145  // Split attribute tables into left and right parts.
146  SHARK_EXPORT_SYMBOL void splitAttributeTables(const AttributeTables& tables, std::size_t index, std::size_t valIndex, AttributeTables& LAttributeTables, AttributeTables& RAttributeTables);
147 
148  /// Build a decision tree for classification
149  SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::SplitMatrixType buildTree(AttributeTables& tables, const ClassificationDataset& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId);
150 
151  /// Builds a decision tree for regression
152  SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::SplitMatrixType buildTree(AttributeTables& tables, const RegressionDataset& dataset, const std::vector<RealVector>& labels, std::size_t nodeId);
153 
154  /// comparison function for sorting an attributeTable
155  SHARK_EXPORT_SYMBOL static bool tableSort(const RFAttribute& v1, const RFAttribute& v2);
156 
157  /// Generate a histogram from the count matrix.
158  SHARK_EXPORT_SYMBOL RealVector hist(boost::unordered_map<std::size_t, std::size_t> countMatrix);
159 
160  /// Average label over a vector.
161  SHARK_EXPORT_SYMBOL RealVector average(const std::vector<RealVector>& labels);
162 
163  /// Calculate the Gini impurity of the countMatrix
164  SHARK_EXPORT_SYMBOL double gini(boost::unordered_map<std::size_t, std::size_t>& countMatrix, std::size_t n);
165 
166  /// Total Sum Of Squares
167  SHARK_EXPORT_SYMBOL double totalSumOfSquares(std::vector<RealVector>& labels, std::size_t from, std::size_t to, const RealVector& sumLabel);
168 
169  /// Generate random table indices.
170  SHARK_EXPORT_SYMBOL void generateRandomTableIndicies(std::set<std::size_t>& tableIndicies);
171 
172  /// Reset the training to its default parameters.
174 
175  /// Number of attributes in the dataset
176  std::size_t m_inputDimension;
177 
178  /// size of labels
179  std::size_t m_labelDimension;
180 
181  /// maximum size of the histogram;
182  /// classification case: maximum number of classes
183  unsigned int m_maxLabel;
184 
185  /// number of attributes to randomly test at each inner node
186  std::size_t m_try;
187 
188  /// number of trees in the forest
189  std::size_t m_B;
190 
191  /// number of samples in the terminal nodes
192  std::size_t m_nodeSize;
193 
194  /// fraction of the data set used for growing trees
195  /// 0 < m_OOBratio < 1
196  double m_OOBratio;
197 
198  /// true if the trainer is used for regression, false otherwise.
200 
201  // true if the feature importances should be computed
203 
204  // true if OOB error should be computed
206 };
207 }
208 #endif