ROC.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief ROC
6  *
7  *
8  *
9  * \author O.Krause
10  * \date 2010-2011
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 #ifndef SHARK_OBJECTIVEFUNCTIONS_ROC_H
35 #define SHARK_OBJECTIVEFUNCTIONS_ROC_H
36 
#include <shark/Core/DLLSupport.h>
#include <shark/Models/AbstractModel.h>
#include <shark/Data/Dataset.h>
#include <vector>
#include <algorithm>
42 
43 namespace shark {
44 
45 //!
46 //! \brief ROC-Curve - false negatives over false positives
47 //!
48 //! \par
49 //! This class provides the ROC curve of a classifier.
50 //! All time consuming computations are done in the constructor,
51 //! such that afterwards fast access to specific values of the
52 //! curve and the equal error rate is possible.
53 //!
54 //! \par
55 //! The ROC class assumes a one dimensional target array and a
56 //! model producing one dimensional output data. The targets must
57 //! be the labels 0 and 1 of a binary classification task. The
58 //! model output is assumed not to be 0 and 1, but real valued
59 //! instead. Classification in done by thresholding, where
60 //! different false positive and false negative rates correspond
61 //! to different thresholds. The ROC curve shows the trade off
62 //! between the two error types.
63 //!
64 class ROC
65 {
66 public:
67  //! Constructor
68  //!
69  //! \param model model to use for prediction
70  //! \param set data set with inputs and corresponding binary outputs (0 or 1)
71  template<class InputType>
73  std::size_t inputs=set.numberOfElements();
74 
75  //calculat the number of classes
76  std::vector<std::size_t> classes = classSizes(set);
77  SIZE_CHECK(classes.size() == 2); //only binary problems allowed!
78 
79  std::size_t positive = classes[0];
80  std::size_t negative = classes[1];
81  m_scorePositive.resize(positive);
82  m_scoreNegative.resize(negative);
83 
84  // compute scores
85  std::size_t posPositive = 0;
86  std::size_t posNegative = 0;
87 
88  //calculate the model responses batchwise for the whole set
89  for(std::size_t i = 0; i != set.size(); ++i){
90  RealMatrix output = model(set.batch(i).input);
91  SIZE_CHECK(output.size2() == 1);
92  for(std::size_t j = 0; j != size(output); ++j){
93  double value = output(j,0);
94  if (set.batch(i)(j) == 1)
95  {
96  m_scorePositive[posPositive] = value;
97  posPositive++;
98  }
99  else
100  {
101  m_scoreNegative[posNegative] = value;
102  posNegative++;
103  }
104  }
105  }
106  // sort positives and negatives by score
107  std::sort(m_scorePositive.begin(), m_scorePositive.end());
108  std::sort(m_scoreNegative.begin(), m_scoreNegative.end());
109  }
110 
111  //! Compute the threshold for given false acceptance rate,
112  //! that is, for a given false positive rate.
113  //! This threshold, used for classification with the underlying
114  //! model, results in the given false acceptance rate.
115  SHARK_EXPORT_SYMBOL double threshold(double falseAcceptanceRate)const;
116 
117  //! Value of the ROC curve for given false acceptance rate,
118  //! that is, for a given false positive rate.
119  SHARK_EXPORT_SYMBOL double value(double falseAcceptanceRate)const;
120 
121  //! Computes the equal error rate of the classifier
122  SHARK_EXPORT_SYMBOL double equalErrorRate()const;
123 
124 protected:
125  //! scores of the positive examples
126  std::vector<double> m_scorePositive;
127 
128  //! scores of the negative examples
129  std::vector<double> m_scoreNegative;
130 };
131 
132 }
133 #endif