RFClassifier.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Random Forest Classifier.
6  *
7  *
8  *
9  * \author K. N. Hansen, O.Krause, J. Kremer
10  * \date 2011-2012
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_TREES_RFCLASSIFIER_H
36 #define SHARK_MODELS_TREES_RFCLASSIFIER_H
37 
39 #include <shark/Models/MeanModel.h>
40 
41 namespace shark {
42 
44 typedef std::vector<SplitMatrixType> ForestInfo;
45 
46 ///
47 /// \brief Random Forest Classifier.
48 ///
49 /// \par
50 /// The Random Forest Classifier predicts a class label
51 /// using the Random Forest algorithm as described in<br/>
52 /// Random Forests. Leo Breiman. Machine Learning, 45(1), pages 5-32. Springer, 2001.<br/>
53 ///
54 /// \par
55 /// It is an ensemble learner that uses multiple decision trees built
56 /// using the CART methodology.
57 ///
58 class RFClassifier : public MeanModel<CARTClassifier<RealVector> >
59 {
60 public:
61  /// \brief From INameable: return the class name.
62  std::string name() const
63  { return "RFClassifier"; }
64 
65  // compute the oob error for the forest
67  std::size_t n_trees = numberOfModels();
68  m_OOBerror = 0;
69  for(std::size_t j=0;j!=n_trees;++j){
70  m_OOBerror += m_models[j].OOBerror();
71  }
72  m_OOBerror /= n_trees;
73  }
74 
75  // compute the feature importances for the forest
78  std::size_t n_trees = numberOfModels();
79 
80  for(std::size_t i=0;i!=m_inputDimension;++i){
81  m_featureImportances[i] = 0;
82  for(std::size_t j=0;j!=n_trees;++j){
83  m_featureImportances[i] += m_models[j].featureImportances()[i];
84  }
85  m_featureImportances[i] /= n_trees;
86  }
87  }
88 
89  double const OOBerror() const {
90  return m_OOBerror;
91  }
92 
93  // returns the feature importances
94  RealVector const& featureImportances() const {
95  return m_featureImportances;
96  }
97 
98  //Count how often attributes are used
99  UIntVector countAttributes() const {
100  std::size_t n = m_models.size();
101  if(!n) return UIntVector();
102  UIntVector r = m_models[0].countAttributes();
103  for(std::size_t i=1; i< n; i++ ) {
104  noalias(r) += m_models[i].countAttributes();
105  }
106  return r;
107  }
108 
109  /// Set the dimension of the labels
110  void setLabelDimension(std::size_t in){
111  m_labelDimension = in;
112  }
113 
114  // Set the input dimension
115  void setInputDimension(std::size_t in){
116  m_inputDimension = in;
117  }
118 
119  ForestInfo getForestInfo() const {
120  ForestInfo finfo(m_models.size());
121  for (std::size_t i=0; i<m_models.size(); ++i)
122  finfo[i]=m_models[i].getSplitMatrix();
123  return finfo;
124  }
125 
126  void setForestInfo(ForestInfo const& finfo, std::vector<double> const& weights = std::vector<double>()) {
127  std::size_t n_tree = finfo.size();
128  std::vector<double> we(weights);
129  m_models.resize(n_tree);
130  if (weights.empty()) // set default weights to 1
131  we.resize(n_tree, 1);
132  else if (weights.size() != n_tree)
133  throw SHARKEXCEPTION("Weights must be the same number as trees");
134 
135  for (std::size_t i=0; i<n_tree; ++i){
136  m_models[i]=finfo[i];
137  m_weight.push_back(we[i]);
138  m_weightSum+=we[i];
139  }
140  }
141 
142 protected:
143  // Dimension of label in the regression case, number of classes in the classification case.
144  std::size_t m_labelDimension;
145 
146  // Input dimension
147  std::size_t m_inputDimension;
148 
149  // oob error for the forest
150  double m_OOBerror;
151 
152  // feature importances for the forest
154 
155 };
156 
157 
158 }
159 #endif