
rf_visitors.hxx

1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 #ifndef RF_VISITORS_HXX
36 #define RF_VISITORS_HXX
37 
38 #ifdef HasHDF5
39 # include "vigra/hdf5impex.hxx"
40 #else
41 # include "vigra/impex.hxx"
42 # include "vigra/multi_array.hxx"
43 # include "vigra/multi_impex.hxx"
44 # include "vigra/inspectimage.hxx"
45 #endif // HasHDF5
46 #include <vigra/windows.h>
47 #include <iostream>
48 #include <iomanip>
49 
50 #include <vigra/multi_pointoperators.hxx>
51 #include <vigra/timing.hxx>
52 
53 namespace vigra
54 {
55 namespace rf
56 {
57 /** \addtogroup MachineLearning Machine Learning
58 **/
59 //@{
60 
61 /**
62  This namespace contains all classes and methods related to extracting information during
 63  learning of the random forest. All visitors share the same interface, defined in
 64  visitors::VisitorBase. The member methods are invoked at certain points of the main code,
 65  in the order in which the visitors were supplied.
 66 
 67  For the random forest, the visitor concept is implemented as a statically linked list
 68  (using templates). Each visitor object is encapsulated in a detail::VisitorNode object. The
 69  VisitorNode calls the next visitor after one of its own visit() methods has returned.
 70 
 71  To simplify usage, create_visitor() factory methods are supplied.
 72  Use create_visitor() to pass visitor objects to the RandomForest::learn() method.
 73  It is possible to supply more than one visitor; they are then invoked in the given order.
 74 
 75  The calculated information is stored in public data members of the respective visitor
 76  class - see the documentation of the individual visitors.
 77 
 78  When creating a new visitor, the new class should publicly inherit from visitors::VisitorBase
 79  (see, e.g., visitors::OOB_Error).
80 
81  \code
82 
83  typedef xxx feature_t;  // replace xxx with whichever type
 84  typedef yyy label_t;    // likewise
 85  MultiArrayView<2, feature_t> f = get_some_features();
 86  MultiArrayView<2, label_t> l = get_some_labels();
 87  RandomForest<> rf;
 88 
 89  //calculate OOB Error
 90  visitors::OOB_Error oob_v;
 91  //calculate Variable Importance
 92  visitors::VariableImportanceVisitor varimp_v;
 93 
 94  double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
95  //the data can be found in the attributes of oob_v and varimp_v now
96 
97  \endcode
98 */
99 namespace visitors
100 {
101 
102 
103 /** Base Class from which all Visitors derive. Can be used as a template to create new
104  * Visitors.
105  */
106 class VisitorBase
107 {
108  public:
109  bool active_;
110  bool is_active()
111  {
112  return active_;
113  }
114 
115  bool has_value()
116  {
117  return false;
118  }
119 
120  VisitorBase()
121  : active_(true)
122  {}
123 
124  void deactivate()
125  {
126  active_ = false;
127  }
128  void activate()
129  {
130  active_ = true;
131  }
132 
133  /** do something after the Split has decided how to process the Region
134  * (Stack entry)
135  *
136  * \param tree reference to the tree that is currently being learned
137  * \param split reference to the split object
138  * \param parent current stack entry which was used to decide the split
139  * \param leftChild left stack entry that will be pushed
140  * \param rightChild
141  * right stack entry that will be pushed.
142  * \param features features matrix
143  * \param labels label matrix
144  * \sa RF_Traits::StackEntry_t
145  */
146  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
147  void visit_after_split( Tree & tree,
148  Split & split,
149  Region & parent,
150  Region & leftChild,
151  Region & rightChild,
152  Feature_t & features,
153  Label_t & labels)
154  {}
155 
156  /** do something after each tree has been learned
157  *
158  * \param rf reference to the random forest object that called this
159  * visitor
160  * \param pr reference to the preprocessor that processed the input
161  * \param sm reference to the sampler object
162  * \param st reference to the first stack entry
163  * \param index index of current tree
164  */
165  template<class RF, class PR, class SM, class ST>
166  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
167  {}
168 
169  /** do something after all trees have been learned
170  *
171  * \param rf reference to the random forest object that called this
172  * visitor
173  * \param pr reference to the preprocessor that processed the input
174  */
175  template<class RF, class PR>
176  void visit_at_end(RF const & rf, PR const & pr)
177  {}
178 
179  /** do something before learning starts
180  *
181  * \param rf reference to the random forest object that called this
182  * visitor
183  * \param pr reference to the Processor class used.
184  */
185  template<class RF, class PR>
186  void visit_at_beginning(RF const & rf, PR const & pr)
187  {}
188  /** do something while traversing the tree after it has been learned
 189  * (external nodes)
 190  *
 191  * \param tr reference to the tree object that called this visitor
 192  * \param index index in the topology_ array we are currently at
 193  * \param node_t type of the node (one of the e_* node tags)
 194  * \param features feature matrix
 195  * \sa NodeTags;
 196  *
 197  * You can create the node by switching on node_tag and using the
 198  * corresponding Node objects, or - if you do not care about the type -
 199  * use the NodeBase class.
200  */
201  template<class TR, class IntT, class TopT,class Feat>
202  void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features)
203  {}
204 
205  /** do something when visiting an internal node after it has been learned
206  *
207  * \sa visit_external_node
208  */
209  template<class TR, class IntT, class TopT,class Feat>
210  void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
211  {}
212 
213  /** return a double value. The value of the first visitor in the chain
 214  * that has a return value is returned by the
 215  * RandomForest::learn() method - or -1.0 if no visitor with a return value
 216  * was supplied. This functionality mainly exists so that the
 217  * OOB visitor can return the OOB error rate, as in the old version
 218  * of the random forest.
219  */
220  double return_val()
221  {
222  return -1.0;
223  }
224 };
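/* Example (editorial sketch, not part of the library): a minimal custom visitor.
   It inherits publicly from VisitorBase and overrides only the hook it needs;
   all other hooks fall back to the empty default implementations above.
   The class and member names are made up for illustration.

   \code
   class TreeCountVisitor : public VisitorBase
   {
     public:
       int trees_learned;

       TreeCountVisitor()
       : trees_learned(0)
       {}

       template<class RF, class PR, class SM, class ST>
       void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
       {
           // called once after each tree has been learned
           ++trees_learned;
       }
   };

   // usage, assuming features f and labels l as in the example above:
   //   TreeCountVisitor tcv;
   //   rf.learn(f, l, create_visitor(tcv));
   \endcode
*/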
225 
226 
227 /** Last Visitor that should be called to stop the recursion.
228  */
229 class StopVisiting: public VisitorBase
230 {
231  public:
232  bool has_value()
233  {
234  return true;
235  }
236  double return_val()
237  {
238  return -1.0;
239  }
240 };
241 namespace detail
242 {
243 /** Container element of the statically linked Visitor list.
 244  *
 245  * Use the create_visitor() factory functions to create visitor chains with up to 10 visitors.
246  *
247  */
248 template <class Visitor, class Next = StopVisiting>
249 class VisitorNode
250 {
251  public:
252 
253  StopVisiting stop_;
254  Next next_;
255  Visitor & visitor_;
256  VisitorNode(Visitor & visitor, Next & next)
257  :
258  next_(next), visitor_(visitor)
259  {}
260 
261  VisitorNode(Visitor & visitor)
262  :
263  next_(stop_), visitor_(visitor)
264  {}
265 
266  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
267  void visit_after_split( Tree & tree,
268  Split & split,
269  Region & parent,
270  Region & leftChild,
271  Region & rightChild,
272  Feature_t & features,
273  Label_t & labels)
274  {
275  if(visitor_.is_active())
276  visitor_.visit_after_split(tree, split,
277  parent, leftChild, rightChild,
278  features, labels);
279  next_.visit_after_split(tree, split, parent, leftChild, rightChild,
280  features, labels);
281  }
282 
283  template<class RF, class PR, class SM, class ST>
284  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
285  {
286  if(visitor_.is_active())
287  visitor_.visit_after_tree(rf, pr, sm, st, index);
288  next_.visit_after_tree(rf, pr, sm, st, index);
289  }
290 
291  template<class RF, class PR>
292  void visit_at_beginning(RF & rf, PR & pr)
293  {
294  if(visitor_.is_active())
295  visitor_.visit_at_beginning(rf, pr);
296  next_.visit_at_beginning(rf, pr);
297  }
298  template<class RF, class PR>
299  void visit_at_end(RF & rf, PR & pr)
300  {
301  if(visitor_.is_active())
302  visitor_.visit_at_end(rf, pr);
303  next_.visit_at_end(rf, pr);
304  }
305 
306  template<class TR, class IntT, class TopT,class Feat>
307  void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
308  {
309  if(visitor_.is_active())
310  visitor_.visit_external_node(tr, index, node_t,features);
311  next_.visit_external_node(tr, index, node_t,features);
312  }
313  template<class TR, class IntT, class TopT,class Feat>
314  void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
315  {
316  if(visitor_.is_active())
317  visitor_.visit_internal_node(tr, index, node_t,features);
318  next_.visit_internal_node(tr, index, node_t,features);
319  }
320 
321  double return_val()
322  {
323  if(visitor_.is_active() && visitor_.has_value())
324  return visitor_.return_val();
325  return next_.return_val();
326  }
327 };
328 
329 } //namespace detail
330 
331 //////////////////////////////////////////////////////////////////////////////
332 // Visitor Factory function up to 10 visitors //
333 //////////////////////////////////////////////////////////////////////////////
334 
335 /** factory method to be used with RandomForest::learn()
336  */
337 template<class A>
338 detail::VisitorNode<A>
 339 create_visitor(A & a)
340 {
341  typedef detail::VisitorNode<A> _0_t;
342  _0_t _0(a);
343  return _0;
344 }
345 
346 
347 /** factory method to be used with RandomForest::learn()
348  */
349 template<class A, class B>
350 detail::VisitorNode<A, detail::VisitorNode<B> >
351 create_visitor(A & a, B & b)
352 {
353  typedef detail::VisitorNode<B> _1_t;
354  _1_t _1(b);
355  typedef detail::VisitorNode<A, _1_t> _0_t;
356  _0_t _0(a, _1);
357  return _0;
358 }
359 
360 
361 /** factory method to be used with RandomForest::learn()
362  */
363 template<class A, class B, class C>
364 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
365 create_visitor(A & a, B & b, C & c)
366 {
367  typedef detail::VisitorNode<C> _2_t;
368  _2_t _2(c);
369  typedef detail::VisitorNode<B, _2_t> _1_t;
370  _1_t _1(b, _2);
371  typedef detail::VisitorNode<A, _1_t> _0_t;
372  _0_t _0(a, _1);
373  return _0;
374 }
375 
376 
377 /** factory method to be used with RandomForest::learn()
378  */
379 template<class A, class B, class C, class D>
380 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
381  detail::VisitorNode<D> > > >
382 create_visitor(A & a, B & b, C & c, D & d)
383 {
384  typedef detail::VisitorNode<D> _3_t;
385  _3_t _3(d);
386  typedef detail::VisitorNode<C, _3_t> _2_t;
387  _2_t _2(c, _3);
388  typedef detail::VisitorNode<B, _2_t> _1_t;
389  _1_t _1(b, _2);
390  typedef detail::VisitorNode<A, _1_t> _0_t;
391  _0_t _0(a, _1);
392  return _0;
393 }
394 
395 
396 /** factory method to be used with RandomForest::learn()
397  */
398 template<class A, class B, class C, class D, class E>
399 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
400  detail::VisitorNode<D, detail::VisitorNode<E> > > > >
401 create_visitor(A & a, B & b, C & c,
402  D & d, E & e)
403 {
404  typedef detail::VisitorNode<E> _4_t;
405  _4_t _4(e);
406  typedef detail::VisitorNode<D, _4_t> _3_t;
407  _3_t _3(d, _4);
408  typedef detail::VisitorNode<C, _3_t> _2_t;
409  _2_t _2(c, _3);
410  typedef detail::VisitorNode<B, _2_t> _1_t;
411  _1_t _1(b, _2);
412  typedef detail::VisitorNode<A, _1_t> _0_t;
413  _0_t _0(a, _1);
414  return _0;
415 }
416 
417 
418 /** factory method to be used with RandomForest::learn()
419  */
420 template<class A, class B, class C, class D, class E,
421  class F>
422 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
423  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
424 create_visitor(A & a, B & b, C & c,
425  D & d, E & e, F & f)
426 {
427  typedef detail::VisitorNode<F> _5_t;
428  _5_t _5(f);
429  typedef detail::VisitorNode<E, _5_t> _4_t;
430  _4_t _4(e, _5);
431  typedef detail::VisitorNode<D, _4_t> _3_t;
432  _3_t _3(d, _4);
433  typedef detail::VisitorNode<C, _3_t> _2_t;
434  _2_t _2(c, _3);
435  typedef detail::VisitorNode<B, _2_t> _1_t;
436  _1_t _1(b, _2);
437  typedef detail::VisitorNode<A, _1_t> _0_t;
438  _0_t _0(a, _1);
439  return _0;
440 }
441 
442 
443 /** factory method to be used with RandomForest::learn()
444  */
445 template<class A, class B, class C, class D, class E,
446  class F, class G>
447 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
448  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
449  detail::VisitorNode<G> > > > > > >
450 create_visitor(A & a, B & b, C & c,
451  D & d, E & e, F & f, G & g)
452 {
453  typedef detail::VisitorNode<G> _6_t;
454  _6_t _6(g);
455  typedef detail::VisitorNode<F, _6_t> _5_t;
456  _5_t _5(f, _6);
457  typedef detail::VisitorNode<E, _5_t> _4_t;
458  _4_t _4(e, _5);
459  typedef detail::VisitorNode<D, _4_t> _3_t;
460  _3_t _3(d, _4);
461  typedef detail::VisitorNode<C, _3_t> _2_t;
462  _2_t _2(c, _3);
463  typedef detail::VisitorNode<B, _2_t> _1_t;
464  _1_t _1(b, _2);
465  typedef detail::VisitorNode<A, _1_t> _0_t;
466  _0_t _0(a, _1);
467  return _0;
468 }
469 
470 
471 /** factory method to be used with RandomForest::learn()
472  */
473 template<class A, class B, class C, class D, class E,
474  class F, class G, class H>
475 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
476  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
477  detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
478 create_visitor(A & a, B & b, C & c,
479  D & d, E & e, F & f,
480  G & g, H & h)
481 {
482  typedef detail::VisitorNode<H> _7_t;
483  _7_t _7(h);
484  typedef detail::VisitorNode<G, _7_t> _6_t;
485  _6_t _6(g, _7);
486  typedef detail::VisitorNode<F, _6_t> _5_t;
487  _5_t _5(f, _6);
488  typedef detail::VisitorNode<E, _5_t> _4_t;
489  _4_t _4(e, _5);
490  typedef detail::VisitorNode<D, _4_t> _3_t;
491  _3_t _3(d, _4);
492  typedef detail::VisitorNode<C, _3_t> _2_t;
493  _2_t _2(c, _3);
494  typedef detail::VisitorNode<B, _2_t> _1_t;
495  _1_t _1(b, _2);
496  typedef detail::VisitorNode<A, _1_t> _0_t;
497  _0_t _0(a, _1);
498  return _0;
499 }
500 
501 
502 /** factory method to be used with RandomForest::learn()
503  */
504 template<class A, class B, class C, class D, class E,
505  class F, class G, class H, class I>
506 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
507  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
508  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
509 create_visitor(A & a, B & b, C & c,
510  D & d, E & e, F & f,
511  G & g, H & h, I & i)
512 {
513  typedef detail::VisitorNode<I> _8_t;
514  _8_t _8(i);
515  typedef detail::VisitorNode<H, _8_t> _7_t;
516  _7_t _7(h, _8);
517  typedef detail::VisitorNode<G, _7_t> _6_t;
518  _6_t _6(g, _7);
519  typedef detail::VisitorNode<F, _6_t> _5_t;
520  _5_t _5(f, _6);
521  typedef detail::VisitorNode<E, _5_t> _4_t;
522  _4_t _4(e, _5);
523  typedef detail::VisitorNode<D, _4_t> _3_t;
524  _3_t _3(d, _4);
525  typedef detail::VisitorNode<C, _3_t> _2_t;
526  _2_t _2(c, _3);
527  typedef detail::VisitorNode<B, _2_t> _1_t;
528  _1_t _1(b, _2);
529  typedef detail::VisitorNode<A, _1_t> _0_t;
530  _0_t _0(a, _1);
531  return _0;
532 }
533 
534 /** factory method to be used with RandomForest::learn()
535  */
536 template<class A, class B, class C, class D, class E,
537  class F, class G, class H, class I, class J>
538 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
539  detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
540  detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
541  detail::VisitorNode<J> > > > > > > > > >
542 create_visitor(A & a, B & b, C & c,
543  D & d, E & e, F & f,
544  G & g, H & h, I & i,
545  J & j)
546 {
547  typedef detail::VisitorNode<J> _9_t;
548  _9_t _9(j);
549  typedef detail::VisitorNode<I, _9_t> _8_t;
550  _8_t _8(i, _9);
551  typedef detail::VisitorNode<H, _8_t> _7_t;
552  _7_t _7(h, _8);
553  typedef detail::VisitorNode<G, _7_t> _6_t;
554  _6_t _6(g, _7);
555  typedef detail::VisitorNode<F, _6_t> _5_t;
556  _5_t _5(f, _6);
557  typedef detail::VisitorNode<E, _5_t> _4_t;
558  _4_t _4(e, _5);
559  typedef detail::VisitorNode<D, _4_t> _3_t;
560  _3_t _3(d, _4);
561  typedef detail::VisitorNode<C, _3_t> _2_t;
562  _2_t _2(c, _3);
563  typedef detail::VisitorNode<B, _2_t> _1_t;
564  _1_t _1(b, _2);
565  typedef detail::VisitorNode<A, _1_t> _0_t;
566  _0_t _0(a, _1);
567  return _0;
568 }
569 
570 //////////////////////////////////////////////////////////////////////////////
571 // Visitors of common interest. //
572 //////////////////////////////////////////////////////////////////////////////
573 
574 
575 /** Visitor to gain information, later needed for online learning.
576  */
577 
578 class OnlineLearnVisitor: public VisitorBase
579 {
580 public:
581  //Set if we adjust thresholds
582  bool adjust_thresholds;
583  //Current tree id
584  int tree_id;
585  //Last node id for finding parent
586  int last_node_id;
587  //Need to now the label for interior node visiting
588  vigra::Int32 current_label;
589  //marginal distribution for interior nodes
590  //
591  OnlineLearnVisitor():
592  adjust_thresholds(false), tree_id(0), last_node_id(0), current_label(0)
593  {}
594  struct MarginalDistribution
595  {
596  ArrayVector<Int32> leftCounts;
597  Int32 leftTotalCounts;
598  ArrayVector<Int32> rightCounts;
599  Int32 rightTotalCounts;
600  double gap_left;
601  double gap_right;
602  };
603  typedef ArrayVector<vigra::Int32> IndexList;
604 
605  //All information for one tree
606  struct TreeOnlineInformation
607  {
608  std::vector<MarginalDistribution> mag_distributions;
609  std::vector<IndexList> index_lists;
610  //map for linear index of mag_distributions
611  std::map<int,int> interior_to_index;
612  //map for linear index of index_lists
613  std::map<int,int> exterior_to_index;
614  };
615 
616  //All trees
617  std::vector<TreeOnlineInformation> trees_online_information;
618 
619  /** Initialize, set the number of trees
620  */
621  template<class RF,class PR>
622  void visit_at_beginning(RF & rf,const PR & pr)
623  {
624  tree_id=0;
625  trees_online_information.resize(rf.options_.tree_count_);
626  }
627 
628  /** Reset a tree
629  */
630  void reset_tree(int tree_id)
631  {
632  trees_online_information[tree_id].mag_distributions.clear();
633  trees_online_information[tree_id].index_lists.clear();
634  trees_online_information[tree_id].interior_to_index.clear();
635  trees_online_information[tree_id].exterior_to_index.clear();
636  }
637 
638  /** simply increase the tree count
639  */
640  template<class RF, class PR, class SM, class ST>
641  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
642  {
643  tree_id++;
644  }
645 
646  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
647  void visit_after_split( Tree & tree,
648  Split & split,
649  Region & parent,
650  Region & leftChild,
651  Region & rightChild,
652  Feature_t & features,
653  Label_t & labels)
654  {
655  int linear_index;
656  int addr=tree.topology_.size();
657  if(split.createNode().typeID() == i_ThresholdNode)
658  {
659  if(adjust_thresholds)
660  {
661  //Store marginal distribution
662  linear_index=trees_online_information[tree_id].mag_distributions.size();
663  trees_online_information[tree_id].interior_to_index[addr]=linear_index;
664  trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
665 
666  trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
667  trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
668 
669  trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
670  trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
671  //Store the gap
672  double gap_left,gap_right;
673  int i;
674  gap_left=features(leftChild[0],split.bestSplitColumn());
675  for(i=1;i<leftChild.size();++i)
676  if(features(leftChild[i],split.bestSplitColumn())>gap_left)
677  gap_left=features(leftChild[i],split.bestSplitColumn());
678  gap_right=features(rightChild[0],split.bestSplitColumn());
679  for(i=1;i<rightChild.size();++i)
680  if(features(rightChild[i],split.bestSplitColumn())<gap_right)
681  gap_right=features(rightChild[i],split.bestSplitColumn());
682  trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
683  trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
684  }
685  }
686  else
687  {
688  //Store index list
689  linear_index=trees_online_information[tree_id].index_lists.size();
690  trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
691 
692  trees_online_information[tree_id].index_lists.push_back(IndexList());
693 
694  trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
695  std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
696  }
697  }
698  void add_to_index_list(int tree,int node,int index)
699  {
700  if(!this->active_)
701  return;
702  TreeOnlineInformation &ti=trees_online_information[tree];
703  ti.index_lists[ti.exterior_to_index[node]].push_back(index);
704  }
705  void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
706  {
707  if(!this->active_)
708  return;
709  trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
710  trees_online_information[src_tree].exterior_to_index.erase(src_index);
711  }
712  /** do something when visiting an internal node during getToLeaf
713  *
714  * remember as last node id, for finding the parent of the last external node
715  * also: adjust class counts and borders
716  */
717  template<class TR, class IntT, class TopT,class Feat>
718  void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
719  {
720  last_node_id=index;
721  if(adjust_thresholds)
722  {
723  vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
724  //Check if we are in the gap
725  double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
726  TreeOnlineInformation &ti=trees_online_information[tree_id];
727  MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
728  if(value>m.gap_left && value<m.gap_right)
729  {
730  //Check which side we want to go to
731  if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
732  {
733  //We want to go left
734  m.gap_left=value;
735  }
736  else
737  {
738  //We want to go right
739  m.gap_right=value;
740  }
741  Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
742  }
743  //Adjust class counts
744  if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
745  {
746  ++m.rightTotalCounts;
747  ++m.rightCounts[current_label];
748  }
749  else
750  {
751  ++m.leftTotalCounts;
752  ++m.leftCounts[current_label];
753  }
754  }
755  }
756  /** do something when visiting an external node during getToLeaf
757  *
758  * Store the new index!
759  */
760 };
761 
762 //////////////////////////////////////////////////////////////////////////////
763 // Out of Bag Error estimates //
764 //////////////////////////////////////////////////////////////////////////////
765 
766 
767 /** Visitor that calculates the oob error of each individual randomized
768  * decision tree.
769  *
770  * After training a tree, all samples that are OOB for this particular tree
 771  * are put down the tree and the error is estimated.
 772  * The per-tree OOB error is the average of these individual error estimates.
 773  * (oobError = average error of one randomized tree)
 774  * Note: this is not the OOB error estimate suggested by Breiman
 775  * (see the OOB_Error visitor for that).
776  */
777 class OOB_PerTreeError: public VisitorBase
778 {
779 public:
780  /** Average error of one randomized decision tree
781  */
782  double oobError;
783 
784  int totalOobCount;
785  ArrayVector<int> oobCount,oobErrorCount;
786 
787  OOB_PerTreeError()
788  : oobError(0.0),
789  totalOobCount(0)
790  {}
791 
792 
793  bool has_value()
794  {
795  return true;
796  }
797 
798 
799  /** does the basic calculation per tree*/
800  template<class RF, class PR, class SM, class ST>
801  void visit_after_tree( RF& rf, PR & pr, SM & sm, ST & st, int index)
802  {
803  //do the first time called.
804  if(int(oobCount.size()) != rf.ext_param_.row_count_)
805  {
806  oobCount.resize(rf.ext_param_.row_count_, 0);
807  oobErrorCount.resize(rf.ext_param_.row_count_, 0);
808  }
809  // go through the samples
810  for(int l = 0; l < rf.ext_param_.row_count_; ++l)
811  {
812  // if the lth sample is oob...
813  if(!sm.is_used()[l])
814  {
815  ++oobCount[l];
816  if( rf.tree(index)
817  .predictLabel(rowVector(pr.features(), l))
818  != pr.response()(l,0))
819  {
820  ++oobErrorCount[l];
821  }
822  }
823 
824  }
825  }
826 
827  /** Does the normalisation
828  */
829  template<class RF, class PR>
830  void visit_at_end(RF & rf, PR & pr)
831  {
832  // do some normalisation
833  for(int l=0; l < (int)rf.ext_param_.row_count_; ++l)
834  {
835  if(oobCount[l])
836  {
837  oobError += double(oobErrorCount[l]) / oobCount[l];
838  ++totalOobCount;
839  }
840  }
841  oobError/=totalOobCount;
842  }
843 
844 };
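/* Usage sketch (editorial, not part of the library): computing the per-tree OOB
   error with OOB_PerTreeError. It assumes features f and labels l are available
   as 2D arrays, as in the namespace example above.

   \code
   visitors::OOB_PerTreeError per_tree_v;
   RandomForest<> rf;
   rf.learn(f, l, visitors::create_visitor(per_tree_v));

   // average error of a single randomized tree:
   double err = per_tree_v.oobError;
   \endcode
*/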
845 
846 /** Visitor that calculates the OOB error of the ensemble.
 847  * This rate should be used to estimate the crossvalidation
 848  * error rate.
 849  * Each sample is put down only those trees for which it is OOB,
 850  * i.e. if sample #1 is OOB for trees 1, 3 and 5, its output is calculated
 851  * using the ensemble consisting only of trees 1, 3 and 5.
 852  *
 853  * With normal bagged sampling, each sample is OOB for approx. 33% of the trees.
 854  * The error rate obtained this way therefore corresponds to the crossvalidation
 855  * error rate of an ensemble containing 33% of the trees.
856  */
857 class OOB_Error : public VisitorBase
858 {
859  typedef MultiArrayShape<2>::type Shp;
860  int class_count;
861  bool is_weighted;
862  MultiArray<2,double> tmp_prob;
863  public:
864 
865  MultiArray<2, double> prob_oob;
866  /** Ensemble oob error rate
867  */
868  double oob_breiman;
869 
870  MultiArray<2, double> oobCount;
871  ArrayVector< int> indices;
872  OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
873 #ifdef HasHDF5
874  void save(std::string filen, std::string pathn)
875  {
876  if(*(pathn.end()-1) != '/')
877  pathn += "/";
878  const char* filename = filen.c_str();
879  MultiArray<2, double> temp(Shp(1,1), 0.0);
880  temp[0] = oob_breiman;
881  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
882  }
883 #endif
884  // negative value if sample was ib, number indicates how often.
885  // value >=0 if sample was oob, 0 means fail 1, correct
886 
887  template<class RF, class PR>
888  void visit_at_beginning(RF & rf, PR & pr)
889  {
890  class_count = rf.class_count();
891  tmp_prob.reshape(Shp(1, class_count), 0);
892  prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
893  is_weighted = rf.options().predict_weighted_;
894  indices.resize(rf.ext_param().row_count_);
895  if(int(oobCount.size()) != rf.ext_param_.row_count_)
896  {
897  oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
898  }
899  for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
900  {
901  indices[ii] = ii;
902  }
903  }
904 
905  template<class RF, class PR, class SM, class ST>
906  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
907  {
908  // go through the samples
909  int total_oob =0;
910  // FIXME: magic number 10000: invoke special treatment when msample << sample_count
 911  // (i.e. the OOB sample is very large)
912  // 40000: use at most 40000 OOB samples per class for OOB error estimate
913  if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
914  {
915  ArrayVector<int> oob_indices;
916  ArrayVector<int> cts(class_count, 0);
917  std::random_shuffle(indices.begin(), indices.end());
918  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
919  {
920  if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
921  {
922  oob_indices.push_back(indices[ii]);
923  ++cts[pr.response()(indices[ii], 0)];
924  }
925  }
926  for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
927  {
928  // update number of trees in which current sample is oob
929  ++oobCount[oob_indices[ll]];
930 
931  // update number of oob samples in this tree.
932  ++total_oob;
933  // get the predicted votes ---> tmp_prob;
934  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
935  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
936  rf.tree(index).parameters_,
937  pos);
938  tmp_prob.init(0);
939  for(int ii = 0; ii < class_count; ++ii)
940  {
941  tmp_prob[ii] = node.prob_begin()[ii];
942  }
943  if(is_weighted)
944  {
945  for(int ii = 0; ii < class_count; ++ii)
946  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
947  }
948  rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
949 
950  }
951  }else
952  {
953  for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
954  {
955  // if the lth sample is oob...
956  if(!sm.is_used()[ll])
957  {
958  // update number of trees in which current sample is oob
959  ++oobCount[ll];
960 
961  // update number of oob samples in this tree.
962  ++total_oob;
963  // get the predicted votes ---> tmp_prob;
964  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
965  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
966  rf.tree(index).parameters_,
967  pos);
968  tmp_prob.init(0);
969  for(int ii = 0; ii < class_count; ++ii)
970  {
971  tmp_prob[ii] = node.prob_begin()[ii];
972  }
973  if(is_weighted)
974  {
975  for(int ii = 0; ii < class_count; ++ii)
976  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
977  }
978  rowVector(prob_oob, ll) += tmp_prob;
979  }
980  }
981  }
982  // go through the ib samples;
983  }
984 
985  /** Normalise the OOB error after the number of trees is known.
986  */
987  template<class RF, class PR>
988  void visit_at_end(RF & rf, PR & pr)
989  {
990  // Ulli's original metric and Breiman-style error
991  int totalOobCount =0;
992  int breimanstyle = 0;
993  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
994  {
995  if(oobCount[ll])
996  {
997  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
998  ++breimanstyle;
999  ++totalOobCount;
1000  }
1001  }
1002  oob_breiman = double(breimanstyle)/totalOobCount;
1003  }
1004 };
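/* Usage sketch (editorial, not part of the library): estimating the ensemble OOB
   error with OOB_Error, again assuming f and l as in the namespace example above.

   \code
   visitors::OOB_Error oob_v;
   RandomForest<> rf;
   rf.learn(f, l, visitors::create_visitor(oob_v));

   // Breiman-style ensemble OOB error rate:
   double err = oob_v.oob_breiman;
   \endcode
*/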
1005 
1006 
1007 /** Visitor that calculates different OOB error statistics
1008  */
1009 class CompleteOOBInfo : public VisitorBase
1010 {
1011  typedef MultiArrayShape<2>::type Shp;
1012  int class_count;
1013  bool is_weighted;
1014  MultiArray<2,double> tmp_prob;
1015  public:
1016 
1017  /** OOB Error rate of each individual tree
1018  */
1020  /** Mean of oob_per_tree
1021  */
1022  double oob_mean;
1023  /** Standard deviation of oob_per_tree
1024  */
1025  double oob_std;
1026 
1027  MultiArray<2, double> prob_oob;
1028  /** Ensemble OOB error
1029  *
1030  * \sa OOB_Error
1031  */
1032  double oob_breiman;
1033 
1034  MultiArray<2, double> oobCount;
1035  MultiArray<2, double> oobErrorCount;
1036  /** Per Tree OOB error calculated as in OOB_PerTreeError
1037  * (Ulli's version)
1038  */
1039  double oob_per_tree2;
1040 
1041  /** Column containing the development of the ensemble
1042  * error rate with increasing number of trees
1043  */
1044  MultiArray<2, double> breiman_per_tree;
1045  /** 4 dimensional array containing the development of confusion matrices
 1046  * with the number of trees - can be used to estimate ROC curves etc.
 1047  *
 1048  * oobroc_per_tree(ii, jj, kk, ll)
 1049  * is the confusion matrix entry for true label = ii and
 1050  * predicted label = jj
 1051  * after ll trees.
 1052  *
 1053  * explanation of the third index:
 1054  *
 1055  * Two-class case:
 1056  * kk = 0 ... (treeCount-1);
 1057  * the threshold on the probability for class 0 is kk/(treeCount-1).
 1058  * More classes:
 1059  * kk = 0; the prediction is determined by the argMax of the probability array.
1060  */
1061  MultiArray<4, double> oobroc_per_tree;
1062 
1064 
1065 #ifdef HasHDF5
1066  /** save to HDF5 file
1067  */
1068  void save(std::string filen, std::string pathn)
1069  {
1070  if(*(pathn.end()-1) != '/')
1071  pathn += "/";
1072  const char* filename = filen.c_str();
1073  MultiArray<2, double> temp(Shp(1,1), 0.0);
1074  writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
1075  writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
1076  writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
1077  temp[0] = oob_mean;
1078  writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
1079  temp[0] = oob_std;
1080  writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
1081  temp[0] = oob_breiman;
1082  writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
1083  temp[0] = oob_per_tree2;
1084  writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
1085  }
1086 #endif
1087  // negative value if sample was ib, number indicates how often.
1088  // value >=0 if sample was oob, 0 means fail 1, correct
1089 
1090  template<class RF, class PR>
1091  void visit_at_beginning(RF & rf, PR & pr)
1092  {
1093  class_count = rf.class_count();
1094  if(class_count == 2)
1095  oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
1096  else
1097  oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
1098  tmp_prob.reshape(Shp(1, class_count), 0);
1099  prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
1100  is_weighted = rf.options().predict_weighted_;
1101  oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1102  breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
1103  //do the first time called.
1104  if(int(oobCount.size()) != rf.ext_param_.row_count_)
1105  {
1106  oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
1107  oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
1108  }
1109  }
1110 
1111  template<class RF, class PR, class SM, class ST>
1112  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
1113  {
1114  // go through the samples
1115  int total_oob =0;
1116  int wrong_oob =0;
1117  for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
1118  {
1119  // if the lth sample is oob...
1120  if(!sm.is_used()[ll])
1121  {
1122  // update number of trees in which current sample is oob
1123  ++oobCount[ll];
1124 
1125  // update number of oob samples in this tree.
1126  ++total_oob;
1127  // get the predicted votes ---> tmp_prob;
1128  int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
1129  Node<e_ConstProbNode> node ( rf.tree(index).topology_,
1130  rf.tree(index).parameters_,
1131  pos);
1132  tmp_prob.init(0);
1133  for(int ii = 0; ii < class_count; ++ii)
1134  {
1135  tmp_prob[ii] = node.prob_begin()[ii];
1136  }
1137  if(is_weighted)
1138  {
1139  for(int ii = 0; ii < class_count; ++ii)
1140  tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
1141  }
1142  rowVector(prob_oob, ll) += tmp_prob;
1143  int label = argMax(tmp_prob);
1144 
1145  if(label != pr.response()(ll, 0))
1146  {
1147  // update number of wrong oob samples in this tree.
1148  ++wrong_oob;
1149  // update number of trees in which current sample is wrong oob
1150  ++oobErrorCount[ll];
1151  }
1152  }
1153  }
1154  int breimanstyle = 0;
1155  int totalOobCount = 0;
1156  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1157  {
1158  if(oobCount[ll])
1159  {
1160  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1161  ++breimanstyle;
1162  ++totalOobCount;
1163  if(oobroc_per_tree.shape(2) == 1)
1164  {
1165  oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
1166  }
1167  }
1168  }
1169  if(oobroc_per_tree.shape(2) == 1)
1170  oobroc_per_tree.bindOuter(index)/=totalOobCount;
1171  if(oobroc_per_tree.shape(2) > 1)
1172  {
1173  MultiArrayView<3, double> current_roc
1174  = oobroc_per_tree.bindOuter(index);
1175  for(int gg = 0; gg < current_roc.shape(2); ++gg)
1176  {
1177  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1178  {
1179  if(oobCount[ll])
1180  {
1181  int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
1182  1 : 0;
1183  current_roc(pr.response()(ll, 0), pred, gg)+= 1;
1184  }
1185  }
1186  current_roc.bindOuter(gg)/= totalOobCount;
1187  }
1188  }
1189  breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
1190  oob_per_tree[index] = double(wrong_oob)/double(total_oob);
1191  // go through the ib samples;
1192  }
1193 
1194  /** Compute the final OOB statistics after the number of trees is known.
1195  */
1196  template<class RF, class PR>
1197  void visit_at_end(RF & rf, PR & pr)
1198  {
1199  // Ulli's original metric and Breiman-style error
1200  oob_per_tree2 = 0;
1201  int totalOobCount =0;
1202  int breimanstyle = 0;
1203  for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
1204  {
1205  if(oobCount[ll])
1206  {
1207  if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
1208  ++breimanstyle;
1209  oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
1210  ++totalOobCount;
1211  }
1212  }
1213  oob_per_tree2 /= totalOobCount;
1214  oob_breiman = double(breimanstyle)/totalOobCount;
1215  // mean error of each tree
1216  MultiArrayView<2, double> mean(Shp(1,1), &oob_mean);
1217  MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
1218  rowStatistics(oob_per_tree, mean, stdDev);
1219  }
1220 };
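/* Usage sketch (editorial, not part of the library): reading the statistics
   collected by CompleteOOBInfo, assuming f and l as in the namespace example
   above. The index choices in the oobroc_per_tree access are only an example
   of the layout described in the member documentation.

   \code
   visitors::CompleteOOBInfo oob_info;
   RandomForest<> rf;
   rf.learn(f, l, visitors::create_visitor(oob_info));

   double mean_tree_error = oob_info.oob_mean;     // mean per-tree OOB error
   double ensemble_error  = oob_info.oob_breiman;  // ensemble OOB error

   // two-class case: confusion matrix entry (true label 0, predicted label 1)
   // at threshold index kk, after all trees:
   //   int kk = 0;
   //   double entry = oob_info.oobroc_per_tree(0, 1, kk, rf.tree_count() - 1);
   \endcode
*/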
1221 
1222 /** calculate variable importance while learning.
1223  */
1224 class VariableImportanceVisitor : public VisitorBase
1225 {
1226  public:
1227 
1228  /** This Array has the same entries as the R - random forest variable
1229  * importance.
1230  * Matrix is featureCount by (classCount +2)
1231  * variable_importance_(ii,jj) is the variable importance measure of
1232  * the ii-th variable according to:
1233  * jj = 0 - (classCount-1)
1234  * classwise permutation importance
1235  * jj = columnCount(variable_importance_) - 2
 1236  * permutation importance
 1237  * jj = columnCount(variable_importance_) - 1
 1238  * gini decrease importance.
1239  *
1240  * permutation importance:
1241  * The difference between the fraction of OOB samples classified correctly
1242  * before and after permuting (randomizing) the ii-th column is calculated.
1243  * The ii-th column is permuted rep_cnt times.
1244  *
1245  * class wise permutation importance:
1246  * same as permutation importance. We only look at those OOB samples whose
1247  * response corresponds to class jj.
1248  *
1249  * gini decrease importance:
1250  * row ii corresponds to the sum of all gini decreases induced by variable ii
1251  * in each node of the random forest.
1252  */
1253  MultiArray<2, double> variable_importance_;
1254  int repetition_count_;
1255  bool in_place_;
1256 
1257 #ifdef HasHDF5
1258  void save(std::string filename, std::string prefix)
1259  {
1260  prefix = "variable_importance_" + prefix;
1261  writeHDF5(filename.c_str(),
1262  prefix.c_str(),
1263  variable_importance_);
1264  }
1265 #endif
1266 
1267  /* Constructor
 1268  * \param rep_cnt (default: 10) how often should
 1269  * the permutation take place. Set to 1 to make the calculation faster (but
 1270  * possibly less stable)
1271  */
1272  VariableImportanceVisitor(int rep_cnt = 10)
1273  : repetition_count_(rep_cnt)
1274 
1275  {}
1276 
1277  /** calculates impurity decrease based variable importance after every
1278  * split.
1279  */
1280  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1281  void visit_after_split( Tree & tree,
1282  Split & split,
1283  Region & parent,
1284  Region & leftChild,
1285  Region & rightChild,
1286  Feature_t & features,
1287  Label_t & labels)
1288  {
1289  //resize to right size when called the first time
1290 
1291  Int32 const class_count = tree.ext_param_.class_count_;
1292  Int32 const column_count = tree.ext_param_.column_count_;
1293  if(variable_importance_.size() == 0)
1294  {
1295 
1296  variable_importance_
1297  .reshape(MultiArrayShape<2>::type(column_count,
1298  class_count+2));
1299  }
1300 
1301  if(split.createNode().typeID() == i_ThresholdNode)
1302  {
1303  Node<i_ThresholdNode> node(split.createNode());
1304  variable_importance_(node.column(),class_count+1)
1305  += split.region_gini_ - split.minGini();
1306  }
1307  }
1308 
1309  /** compute permutation based variable importance.
 1310  * (Only an array of size oob_sample_count x 1 is created
 1311  * - as opposed to oob_sample_count x feature_count in the other method.)
1312  *
1313  * \sa FieldProxy
1314  */
1315  template<class RF, class PR, class SM, class ST>
1316  void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & st, int index)
1317  {
1318  typedef MultiArrayShape<2>::type Shp_t;
1319  Int32 column_count = rf.ext_param_.column_count_;
1320  Int32 class_count = rf.ext_param_.class_count_;
1321 
1322  /* This solution saves memory uptake but not multithreading
1323  * compatible
1324  */
1325  // remove the const cast on the features (yep , I know what I am
1326  // doing here.) data is not destroyed.
1327  //typename PR::Feature_t & features
1328  // = const_cast<typename PR::Feature_t &>(pr.features());
1329 
1330  typedef typename PR::FeatureWithMemory_t FeatureArray;
1331  typedef typename FeatureArray::value_type FeatureValue;
1332 
1333  FeatureArray features = pr.features();
1334 
1335  //find the oob indices of current tree.
1336  ArrayVector<Int32> oob_indices;
1337  ArrayVector<Int32>::iterator
1338  iter;
1339  for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1340  if(!sm.is_used()[ii])
1341  oob_indices.push_back(ii);
1342 
1343  //create space to back up a column
1344  ArrayVector<FeatureValue> backup_column;
1345 
1346  // Random foo
1347 #ifdef CLASSIFIER_TEST
1348  RandomMT19937 random(1);
1349 #else
1350  RandomMT19937 random(RandomSeed);
1351 #endif
1352  UniformIntRandomFunctor<RandomMT19937>
1353  randint(random);
1354 
1355 
1356  //make some space for the results
1357  MultiArray<2, double>
1358  oob_right(Shp_t(1, class_count + 1));
1359  MultiArray<2, double>
1360  perm_oob_right (Shp_t(1, class_count + 1));
1361 
1362 
1363  // get the oob success rate with the original samples
1364  for(iter = oob_indices.begin();
1365  iter != oob_indices.end();
1366  ++iter)
1367  {
1368  if(rf.tree(index)
1369  .predictLabel(rowVector(features, *iter))
1370  == pr.response()(*iter, 0))
1371  {
1372  //per class
1373  ++oob_right[pr.response()(*iter,0)];
1374  //total
1375  ++oob_right[class_count];
1376  }
1377  }
1378  //get the oob rate after permuting the ii'th dimension.
1379  for(int ii = 0; ii < column_count; ++ii)
1380  {
1381  perm_oob_right.init(0.0);
1382  //make backup of original column
1383  backup_column.clear();
1384  for(iter = oob_indices.begin();
1385  iter != oob_indices.end();
1386  ++iter)
1387  {
1388  backup_column.push_back(features(*iter,ii));
1389  }
1390 
1391  //get the oob rate after permuting the ii'th dimension.
1392  for(int rr = 0; rr < repetition_count_; ++rr)
1393  {
1394  //permute dimension.
1395  int n = oob_indices.size();
1396  for(int jj = 1; jj < n; ++jj)
1397  std::swap(features(oob_indices[jj], ii),
1398  features(oob_indices[randint(jj+1)], ii));
1399 
1400  //get the oob success rate after permuting
1401  for(iter = oob_indices.begin();
1402  iter != oob_indices.end();
1403  ++iter)
1404  {
1405  if(rf.tree(index)
1406  .predictLabel(rowVector(features, *iter))
1407  == pr.response()(*iter, 0))
1408  {
1409  //per class
1410  ++perm_oob_right[pr.response()(*iter, 0)];
1411  //total
1412  ++perm_oob_right[class_count];
1413  }
1414  }
1415  }
1416 
1417 
1418  //normalise and add to the variable_importance array.
1419  perm_oob_right /= repetition_count_;
1420  perm_oob_right -=oob_right;
1421  perm_oob_right *= -1;
1422  perm_oob_right /= oob_indices.size();
1423  variable_importance_
1424  .subarray(Shp_t(ii,0),
1425  Shp_t(ii+1,class_count+1)) += perm_oob_right;
1426  //copy back permuted dimension
1427  for(int jj = 0; jj < int(oob_indices.size()); ++jj)
1428  features(oob_indices[jj], ii) = backup_column[jj];
1429  }
1430  }
1431 
1432  /** calculate permutation based importance after every tree has been
 1433  * learned. The default behaviour is that this happens out of place.
 1434  * If you have very big data sets and want to avoid copying of data,
 1435  * set the in_place_ flag to true.
1436  */
1437  template<class RF, class PR, class SM, class ST>
1438  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
1439  {
1440  after_tree_ip_impl(rf, pr, sm, st, index);
1441  }
1442 
1443  /** Normalise variable importance after the number of trees is known.
1444  */
1445  template<class RF, class PR>
1446  void visit_at_end(RF & rf, PR & pr)
1447  {
1448  variable_importance_ /= rf.trees_.size();
1449  }
1450 };
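/* Usage sketch (editorial, not part of the library): reading the columns of
   variable_importance_ as described in the member documentation, assuming f and
   l as in the namespace example above.

   \code
   visitors::VariableImportanceVisitor varimp_v;
   RandomForest<> rf;
   rf.learn(f, l, visitors::create_visitor(varimp_v));

   int class_count = rf.class_count();
   // for feature 0:
   //   columns 0 .. class_count-1 : class-wise permutation importance
   //   column  class_count        : permutation importance
   //   column  class_count + 1    : gini decrease importance
   double perm_importance = varimp_v.variable_importance_(0, class_count);
   double gini_importance = varimp_v.variable_importance_(0, class_count + 1);
   \endcode
*/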
1451 
1452 /** Verbose output
1453  */
1454 class RandomForestProgressVisitor : public VisitorBase {
1455  public:
1457 
1458  template<class RF, class PR, class SM, class ST>
1459  void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index){
1460  if(index != rf.options().tree_count_-1) {
1461  std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
1462  << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
1463  }
1464  else {
1465  std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
1466  }
1467  }
1468 
1469  template<class RF, class PR>
1470  void visit_at_end(RF const & rf, PR const & pr) {
1471  std::string a = TOCS;
1472  std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl;
1473  }
1474 
1475  template<class RF, class PR>
1476  void visit_at_beginning(RF const & rf, PR const & pr) {
1477  TIC;
1478  std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
1479  }
1480 
1481  private:
1482  USETICTOC;
1483 };
1484 
1485 
1486 /** Computes Correlation/Similarity Matrix of features while learning
1487  * the random forest.
1488  */
1489 class CorrelationVisitor : public VisitorBase
1490 {
1491  public:
1492  /** gini_missc(ii, jj) describes how well variable jj can describe a partition
1493  * created on variable ii(when variable ii was chosen)
1494  */
1495  MultiArray<2, double> gini_missc;
1496  MultiArray<2, int> tmp_labels;
1497  /** additional noise features.
1498  */
1499  MultiArray<2, double> noise;
1500  MultiArray<2, double> noise_l;
1501  /** how well can a noise column describe a partition created on variable ii.
1502  */
1503  MultiArray<2, double> corr_noise;
1504  MultiArray<2, double> corr_l;
1505 
1506  /** Similarity Matrix
1507  *
1508  * (numberOfFeatures + 1) by (numberOfFeatures + 1) Matrix
1509  * gini_missc
1510  * - row normalized by the number of times the column was chosen
1511  * - mean of corr_noise subtracted
1512  * - and symmetrised.
1513  *
1514  */
1515  MultiArray<2, double> similarity;
1516  /** Distance Matrix 1-similarity
1517  */
1518  MultiArray<2, double> distance;
1519  ArrayVector<int> tmp_cc;
1520 
1521  /** How often was variable ii chosen
1522  */
1523  ArrayVector<int> numChoices;
 1524  typedef BestGiniOfColumn<GiniCriterion> ColumnDecisionFunctor;
 1525  ColumnDecisionFunctor bgfunc;
1526  void save(std::string file, std::string prefix)
1527  {
1528  /*
1529  std::string tmp;
1530 #define VAR_WRITE(NAME) \
1531  tmp = #NAME;\
1532  tmp += "_";\
1533  tmp += prefix;\
1534  vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
1535  VAR_WRITE(gini_missc);
1536  VAR_WRITE(corr_noise);
1537  VAR_WRITE(distance);
1538  VAR_WRITE(similarity);
1539  vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
1540 #undef VAR_WRITE
1541 */
1542  }
1543  template<class RF, class PR>
1544  void visit_at_beginning(RF const & rf, PR & pr)
1545  {
1546  typedef MultiArrayShape<2>::type Shp;
1547  int n = rf.ext_param_.column_count_;
1548  gini_missc.reshape(Shp(n +1,n+ 1));
1549  corr_noise.reshape(Shp(n + 1, 10));
1550  corr_l.reshape(Shp(n +1, 10));
1551 
1552  noise.reshape(Shp(pr.features().shape(0), 10));
1553  noise_l.reshape(Shp(pr.features().shape(0), 10));
1554  RandomMT19937 random(RandomSeed);
1555  for(int ii = 0; ii < noise.size(); ++ii)
1556  {
1557  noise[ii] = random.uniform53();
1558  noise_l[ii] = random.uniform53() > 0.5;
1559  }
1560  bgfunc = ColumnDecisionFunctor( rf.ext_param_);
1561  tmp_labels.reshape(pr.response().shape());
1562  tmp_cc.resize(2);
1563  numChoices.resize(n+1);
1564  // look at all axes
1565  }
1566  template<class RF, class PR>
1567  void visit_at_end(RF const & rf, PR const & pr)
1568  {
1569  typedef MultiArrayShape<2>::type Shp;
1570  similarity = gini_missc;
1572  MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1));
1573  rowStatistics(corr_noise, mean_noise);
1574  mean_noise/= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data());
1575  int rC = similarity.shape(0);
1576  for(int jj = 0; jj < rC-1; ++jj)
1577  {
1578  rowVector(similarity, jj) /= numChoices[jj];
1579  rowVector(similarity, jj) -= mean_noise(jj, 0);
1580  }
1581  for(int jj = 0; jj < rC; ++jj)
1582  {
1583  similarity(rC -1, jj) /= numChoices[jj];
1584  }
1585  rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0);
1587  FindMinMax<double> minmax;
1588  inspectMultiArray(srcMultiArrayRange(similarity), minmax);
1589 
1590  for(int jj = 0; jj < rC; ++jj)
1591  similarity(jj, jj) = minmax.max;
1592 
1593  similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))
1594  += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
1595  similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;
1596  columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
1597  for(int jj = 0; jj < rC; ++jj)
1598  similarity(jj, jj) = 0;
1599 
1600  FindMinMax<double> minmax2;
1601  inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
1602  for(int jj = 0; jj < rC; ++jj)
1603  similarity(jj, jj) = minmax2.max;
1604  distance.reshape(gini_missc.shape(), minmax2.max);
1605  distance -= similarity;
1606  }
1607 
1608  template<class Tree, class Split, class Region, class Feature_t, class Label_t>
1609  void visit_after_split( Tree & tree,
1610  Split & split,
1611  Region & parent,
1612  Region & leftChild,
1613  Region & rightChild,
1614  Feature_t & features,
1615  Label_t & labels)
1616  {
1617  if(split.createNode().typeID() == i_ThresholdNode)
1618  {
1619  double wgini;
1620  tmp_cc.init(0);
1621  for(int ii = 0; ii < parent.size(); ++ii)
1622  {
1623  tmp_labels[parent[ii]]
1624  = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
1625  ++tmp_cc[tmp_labels[parent[ii]]];
1626  }
1627  double region_gini = bgfunc.loss_of_region(tmp_labels,
1628  parent.begin(),
1629  parent.end(),
1630  tmp_cc);
1631 
1632  int n = split.bestSplitColumn();
1633  ++numChoices[n];
1634  ++(*(numChoices.end()-1));
1635  //this functor does all the work
1636  for(int k = 0; k < features.shape(1); ++k)
1637  {
1638  bgfunc(columnVector(features, k),
1639  tmp_labels,
1640  parent.begin(), parent.end(),
1641  tmp_cc);
1642  wgini = (region_gini - bgfunc.min_gini_);
1643  gini_missc(n, k)
1644  += wgini;
1645  }
1646  for(int k = 0; k < 10; ++k)
1647  {
1648  bgfunc(columnVector(noise, k),
1649  tmp_labels,
1650  parent.begin(), parent.end(),
1651  tmp_cc);
1652  wgini = (region_gini - bgfunc.min_gini_);
1653  corr_noise(n, k)
1654  += wgini;
1655  }
1656 
1657  for(int k = 0; k < 10; ++k)
1658  {
1659  bgfunc(columnVector(noise_l, k),
1660  tmp_labels,
1661  parent.begin(), parent.end(),
1662  tmp_cc);
1663  wgini = (region_gini - bgfunc.min_gini_);
1664  corr_l(n, k)
1665  += wgini;
1666  }
1667  bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
1668  wgini = (region_gini - bgfunc.min_gini_);
1669  gini_missc(n, columnCount(gini_missc)-1)
1670  += wgini;
1671 
1672  region_gini = split.region_gini_;
1673 #if 1
1674  Node<i_ThresholdNode> node(split.createNode());
1675  gini_missc(node.column(),
1676  node.column())
1677  +=split.region_gini_ - split.minGini();
1678 #endif
1679  for(int k = 0; k < 10; ++k)
1680  {
1681  split.bgfunc(columnVector(noise, k),
1682  labels,
1683  parent.begin(), parent.end(),
1684  parent.classCounts());
1685  corr_noise(node.column(),
1686  k)
1687  += wgini;
1688  }
1689 #if 0
1690  for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
1691  {
1692  wgini = region_gini - split.min_gini_[k];
1693 
1695  split.splitColumns[k])
1696  += wgini;
1697  }
1698 
1699  for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
1700  {
1701  split.bgfunc(columnVector(features, split.splitColumns[k]),
1702  labels,
1703  parent.begin(), parent.end(),
1704  parent.classCounts());
1705  wgini = region_gini - split.bgfunc.min_gini_;
1707  split.splitColumns[k]) += wgini;
1708  }
1709 #endif
1710  // remember to partition the data according to the best.
1711  gini_missc(n,
1712  columnCount(gini_missc)-1)
1713  += region_gini;
1714  SortSamplesByDimensions<Feature_t>
1715  sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
1716  std::partition(parent.begin(), parent.end(), sorter);
1717  }
1718  }
1719 };
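/* Usage sketch (editorial, not part of the library): reading the similarity and
   distance matrices after learning. It assumes the visitor class name
   CorrelationVisitor as restored above, and f and l as in the namespace example.

   \code
   visitors::CorrelationVisitor corr_v;
   RandomForest<> rf;
   rf.learn(f, l, visitors::create_visitor(corr_v));

   // (numberOfFeatures + 1) x (numberOfFeatures + 1) matrices;
   // the last row/column corresponds to the class label
   MultiArray<2, double> const & S = corr_v.similarity;
   MultiArray<2, double> const & D = corr_v.distance;
   \endcode
*/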
1720 
1721 
1722 } // namespace visitors
1723 } // namespace rf
1724 } // namespace vigra
1725 
1726 //@}
1727 #endif // RF_VISITORS_HXX