// CARTClassifier.h
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Cart Classifier
6  *
7  *
8  *
9  * \author K. N. Hansen, J. Kremer
10  * \date 2012
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_MODELS_TREES_CARTCLASSIFIER_H
36 #define SHARK_MODELS_TREES_CARTCLASSIFIER_H
37 
38 
42 #include <shark/Data/Dataset.h>
43 
44 namespace shark {
45 
46 
47 ///
48 /// \brief CART Classifier.
49 ///
50 /// \par
51 /// The CARTClassifier predicts a class label
52 /// using the CART algorithm.
53 ///
54 /// \par
55 /// It is a decision tree algorithm.
56 ///
template<class LabelType>
class CARTClassifier : public AbstractModel<RealVector,LabelType>
{
private:
public:
	// NOTE(review): this copy of the file is missing several source lines
	// (the extraction skipped them): the base_type typedef block, the
	// default-constructor signature, the SplitInfo::attributeValue member,
	// the loss declarations in the OOB routines, and the declaration of
	// m_featureImportances. Names such as base_type, BatchInputType,
	// BatchOutputType, lossOOB and attributeValue are used below but not
	// declared in this view — confirm against the full file before editing.

	/// Information about a single split, i.e. one node of the tree.
	/// misclassProp, r and g are scratch variables used in the
	/// cost-complexity pruning step.
	struct SplitInfo{
		std::size_t nodeId;          // id of this node (equals its index after optimizeSplitMatrix)
		std::size_t attributeIndex;  // input dimension tested at this node
		// NOTE(review): a `double attributeValue;` member appears to be missing
		// from this view — it is serialized and read in evalPattern below.
		std::size_t leftNodeId;      // child followed when attribute <= threshold; 0 marks a leaf
		std::size_t rightNodeId;     // child followed when attribute > threshold
		LabelType label;             // prediction stored at this node (used when it is a leaf)
		double misclassProp;//TODO: remove this
		std::size_t r;//TODO: remove this
		double g;//TODO: remove this

		/// Boost.Serialization hook: (de)serializes every field of the node.
		template<class Archive>
		void serialize(Archive & ar, const unsigned int version){
			ar & nodeId;
			ar & attributeIndex;
			ar & attributeValue;
			ar & leftNodeId;
			ar & rightNodeId;
			ar & label;
			ar & misclassProp;
			ar & r;
			ar & g;
		}
	};

	/// Vector of structs that contains the splitting information and the labels.
	/// The class label is a normalized histogram in the classification case.
	/// In the regression case, the label is the regression value.
	typedef std::vector<SplitInfo> SplitMatrixType;

	/// Constructor
	// NOTE(review): the `CARTClassifier()` signature line is missing from this
	// view; only the empty body survived the extraction.
	{}

	/// Constructor taking the splitMatrix as argument
	CARTClassifier(SplitMatrixType const& splitMatrix)
	{
		m_splitMatrix=splitMatrix;
	}

	/// Constructor taking the splitMatrix as argument and optimize it if requested
	CARTClassifier(SplitMatrixType const& splitMatrix, bool optimize)
	{
		if (optimize)
			setSplitMatrix(splitMatrix);
		else
			m_splitMatrix=splitMatrix;
	}

	/// Constructor taking the splitMatrix as argument as well as maximum number of attributes
	CARTClassifier(SplitMatrixType const& splitMatrix, std::size_t d)
	{
		setSplitMatrix(splitMatrix);
		m_inputDimension = d;
	}

	/// \brief From INameable: return the class name.
	std::string name() const
	{ return "CARTClassifier"; }

	/// The model is stateless; evaluation needs only an empty state object.
	boost::shared_ptr<State> createState()const{
		return boost::shared_ptr<State>(new EmptyState());
	}

	using base_type::eval;
	/// \brief Evaluate the Tree on a batch of patterns
	void eval(const BatchInputType& patterns, BatchOutputType& outputs)const{
		std::size_t numPatterns = shark::size(patterns);
		//evaluate the first pattern alone and create the batch output from that
		LabelType const& firstResult = evalPattern(row(patterns,0));
		outputs = Batch<LabelType>::createBatch(firstResult,numPatterns);
		get(outputs,0) = firstResult;

		//evaluate the rest
		// NOTE(review): the loop starts at 0, so pattern 0 is evaluated twice
		// (once above to size the batch, once here). Starting at i = 1 would
		// avoid the redundant work without changing the result.
		for(std::size_t i = 0; i != numPatterns; ++i){
			get(outputs,i) = evalPattern(row(patterns,i));
		}
	}

	/// \brief Batch evaluation with an (unused) state object; forwards to eval above.
	void eval(const BatchInputType& patterns, BatchOutputType& outputs, State& state)const{
		eval(patterns,outputs);
	}
	/// \brief Evaluate the Tree on a single pattern
	void eval(RealVector const & pattern, LabelType& output){
		output = evalPattern(pattern);
	}

	/// Set the model split matrix.
	void setSplitMatrix(SplitMatrixType const& splitMatrix){
		m_splitMatrix = splitMatrix;
		// NOTE(review): one line is missing here in this view — presumably the
		// optimizeSplitMatrix(m_splitMatrix) call that converts node IDs to
		// indices (constructors route through this method to "optimize").
	}

	/// Get the model split matrix.
	// NOTE(review): returns by value, i.e. copies the whole vector on each call.
	SplitMatrixType getSplitMatrix() const {
		return m_splitMatrix;
	}

	/// \brief The model does not have any parameters.
	std::size_t numberOfParameters()const{
		return 0;
	}

	/// \brief The model does not have any parameters.
	RealVector parameterVector() const {
		return RealVector();
	}

	/// \brief The model does not have any parameters.
	void setParameterVector(const RealVector& param) {
		SHARK_ASSERT(param.size() == 0);
	}

	/// from ISerializable, reads a model from an archive
	void read(InArchive& archive){
		archive >> m_splitMatrix;
	}

	/// from ISerializable, writes a model to an archive
	void write(OutArchive& archive) const {
		archive << m_splitMatrix;
	}


	/// Count how often each input attribute is used as a split criterion;
	/// entry i is the number of internal nodes testing dimension i.
	// NOTE(review): one line is missing before the declaration of r in this
	// view — likely a size/sanity check on m_inputDimension.
	UIntVector countAttributes() const {
		UIntVector r(m_inputDimension, 0);
		typename SplitMatrixType::const_iterator it;
		for(it = m_splitMatrix.begin(); it != m_splitMatrix.end(); ++it) {
			//std::cout << "NodeId: " <<it->leftNodeId << std::endl;
			if(it->leftNodeId != 0) { // not a label, i.e. an internal node
				r(it->attributeIndex)++;
			}
		}
		return r;
	}

	///Return input dimension
	std::size_t inputSize() const {
		return m_inputDimension;
	}

	///Set input dimension
	void setInputDimension(std::size_t d) {
		m_inputDimension = d;
	}

	/// Compute oob error, given an oob dataset (Classification)
	// NOTE(review): the function signature line and the declaration of lossOOB
	// are missing from this view; lossOOB is presumably a zero-one loss over
	// the class histograms — confirm against the full file.
	// define loss

		// predict oob data
		Data<RealVector> predOOB = (*this)(dataOOB.inputs());

		// count average number of oob misclassifications
		m_OOBerror = lossOOB.eval(dataOOB.labels(), predOOB);
	}

	/// Compute oob error, given an oob dataset (Regression)
	// NOTE(review): the declaration of lossOOB is missing from this view;
	// presumably a squared-error loss — confirm against the full file.
	void computeOOBerror(const RegressionDataset& dataOOB){
		// define loss

		// predict oob data
		Data<RealVector> predOOB = (*this)(dataOOB.inputs());

		// Compute mean squared error
		m_OOBerror = lossOOB.eval(dataOOB.labels(), predOOB);
	}

	/// Return OOB error
	double OOBerror() const {
		return m_OOBerror;
	}

	/// Return feature importances
	RealVector const& featureImportances() const {
		return m_featureImportances;
	}

	/// Compute feature importances, given an oob dataset (Classification).
	/// Importance of dimension i = |accuracy on OOB data - accuracy after
	/// permuting dimension i across the OOB samples|.
	// NOTE(review): the function signature, the lossOOB declaration and the
	// resizing of m_featureImportances are missing from this view.

		// define loss

		// compute oob error
		computeOOBerror(dataOOB);

		// count average number of correct oob predictions
		double accuracyOOB = 1. - m_OOBerror;

		// go through all dimensions, permute each dimension across all elements and train the tree on it
		for(std::size_t i=0;i!=m_inputDimension;++i) {
			// create permuted dataset by copying
			ClassificationDataset pDataOOB(dataOOB);
			pDataOOB.makeIndependent();

			// permute current dimension
			// NOTE(review): std::random_shuffle is deprecated in C++14 and
			// removed in C++17; it also makes the importances depend on the
			// global rand() state.
			RealVector v = getColumn(pDataOOB.inputs(), i);
			std::random_shuffle(v.begin(), v.end());
			setColumn(pDataOOB.inputs(), i, v);

			// evaluate the data set for which one feature dimension was permuted with this tree
			Data<RealVector> pPredOOB = (*this)(pDataOOB.inputs());

			// count the number of correct predictions
			double accuracyPermutedOOB = 1. - lossOOB.eval(pDataOOB.labels(),pPredOOB);

			// store importance
			m_featureImportances[i] = std::fabs(accuracyOOB - accuracyPermutedOOB);
		}
	}

	/// Compute feature importances, given an oob dataset (Regression).
	/// Importance of dimension i = |MSE after permuting dimension i - MSE on
	/// the unpermuted OOB data|.
	// NOTE(review): the function signature, the lossOOB declaration and the
	// resizing of m_featureImportances are missing from this view.

		// define loss

		// compute oob error
		computeOOBerror(dataOOB);

		// mean squared error for oob sample
		double mseOOB = m_OOBerror;

		// go through all dimensions, permute each dimension across all elements and train the tree on it
		for(std::size_t i=0;i!=m_inputDimension;++i) {
			// create permuted dataset by copying
			RegressionDataset pDataOOB(dataOOB);
			pDataOOB.makeIndependent();

			// permute current dimension
			RealVector v = getColumn(pDataOOB.inputs(), i);
			std::random_shuffle(v.begin(), v.end());
			setColumn(pDataOOB.inputs(), i, v);

			// evaluate the data set for which one feature dimension was permuted with this tree
			Data<RealVector> pPredOOB = (*this)(pDataOOB.inputs());

			// mean squared error of permuted oob sample
			double msePermutedOOB = lossOOB.eval(pDataOOB.labels(),pPredOOB);

			// store importance
			m_featureImportances[i] = std::fabs(msePermutedOOB - mseOOB);
		}
	}

protected:
	/// split matrix of the model
	SplitMatrixType m_splitMatrix;

	/// \brief Finds the index of the node with a certain nodeID in an unoptimized split matrix.
	// NOTE(review): linear scan with no bounds check — if nodeId is not present
	// in m_splitMatrix this indexes past the end (undefined behavior).
	std::size_t findNode(std::size_t nodeId)const{
		std::size_t index = 0;
		for(; nodeId != m_splitMatrix[index].nodeId; ++index);
		return index;
	}

	/// Optimize a split matrix, so constant lookup can be used.
	/// The optimization is done by changing the index of the children
	/// to use indices instead of node ID.
	/// Furthermore, the node IDs are converted to index numbers.
	// NOTE(review): O(n^2) overall — findNode is a linear scan per node; fine
	/// for small trees, worth a hash map if trees grow large.
	void optimizeSplitMatrix(SplitMatrixType& splitMatrix)const{
		for(std::size_t i = 0; i < splitMatrix.size(); i++){
			splitMatrix[i].leftNodeId = findNode(splitMatrix[i].leftNodeId);
			splitMatrix[i].rightNodeId = findNode(splitMatrix[i].rightNodeId);
		}
		for(std::size_t i = 0; i < splitMatrix.size(); i++){
			splitMatrix[i].nodeId = i;
		}
	}

	/// Evaluate the CART tree on a single sample: walk from the root (index 0),
	/// branching left when the tested attribute is <= the node's threshold,
	/// until a leaf (leftNodeId == 0) is reached; return that leaf's label.
	/// Requires an optimized split matrix (child fields are indices).
	template<class Vector>
	LabelType const& evalPattern(Vector const& pattern)const{
		std::size_t nodeId = 0;
		while(m_splitMatrix[nodeId].leftNodeId != 0){
			if(pattern[m_splitMatrix[nodeId].attributeIndex]<=m_splitMatrix[nodeId].attributeValue){
				//Branch on left node
				nodeId = m_splitMatrix[nodeId].leftNodeId;
			}else{
				//Branch on right node
				nodeId = m_splitMatrix[nodeId].rightNodeId;
			}
		}
		return m_splitMatrix[nodeId].label;
	}


	///Number of attributes (set by trainer)
	std::size_t m_inputDimension;

	// feature importances
	// NOTE(review): the `RealVector m_featureImportances;` declaration is
	// missing from this view — it is used above.

	// oob error
	double m_OOBerror;
};
369 
370 
371 }
372 #endif