shark::CrossEntropy Class Reference

Error measure for classification tasks that can be used as the objective function for training. More...

#include <shark/ObjectiveFunctions/Loss/CrossEntropy.h>

+ Inheritance diagram for shark::CrossEntropy:

Public Member Functions

 CrossEntropy ()
 
std::string name () const
 From INameable: return the class name. More...
 
double eval (UIntVector const &target, RealMatrix const &prediction) const
 
double evalDerivative (UIntVector const &target, RealMatrix const &prediction, RealMatrix &gradient) const
 
double evalDerivative (unsigned int const &target, RealVector const &prediction, RealVector &gradient, RealMatrix &hessian) const
 evaluate the loss and its first and second derivative for a target and a prediction More...
 
- Public Member Functions inherited from shark::AbstractLoss< unsigned int, RealVector >
 AbstractLoss ()
 
virtual double eval (BatchLabelType const &target, BatchOutputType const &prediction) const=0
 evaluate the loss for a batch of targets and a prediction More...
 
virtual double eval (LabelType const &target, OutputType const &prediction) const
 evaluate the loss for a target and a prediction More...
 
double eval (Data< LabelType > const &targets, Data< OutputType > const &predictions) const
 
virtual double evalDerivative (LabelType const &target, OutputType const &prediction, OutputType &gradient) const
 evaluate the loss and its derivative for a target and a prediction More...
 
virtual double evalDerivative (BatchLabelType const &target, BatchOutputType const &prediction, BatchOutputType &gradient) const
 evaluate the loss and the derivative w.r.t. the prediction More...
 
double operator() (LabelType const &target, OutputType const &prediction) const
 evaluate the loss for a target and a prediction More...
 
double operator() (BatchLabelType const &target, BatchOutputType const &prediction) const
 
- Public Member Functions inherited from shark::AbstractCost< unsigned int, RealVector >
virtual ~AbstractCost ()
 
const Features & features () const
 
virtual void updateFeatures ()
 
bool hasFirstDerivative () const
 returns true when the first parameter derivative is implemented More...
 
bool isLossFunction () const
 returns true when the cost function is in fact a loss function More...
 
double operator() (Data< LabelType > const &targets, Data< OutputType > const &predictions) const
 
- Public Member Functions inherited from shark::INameable
virtual ~INameable ()
 

Additional Inherited Members

- Public Types inherited from shark::AbstractLoss< unsigned int, RealVector >
typedef RealVector OutputType
 
typedef unsigned int LabelType
 
typedef VectorMatrixTraits< OutputType >::DenseMatrixType MatrixType
 
typedef Batch< OutputType >::type BatchOutputType
 
typedef Batch< LabelType >::type BatchLabelType
 
- Public Types inherited from shark::AbstractCost< unsigned int, RealVector >
enum  Feature
 list of features a cost function can have More...
 
typedef RealVector OutputType
 
typedef unsigned int LabelType
 
typedef Batch< OutputType >::type BatchOutputType
 
typedef Batch< LabelType >::type BatchLabelType
 
typedef TypedFlags< Feature > Features
 
typedef TypedFeatureNotAvailableException< Feature > FeatureNotAvailableException
 
- Protected Attributes inherited from shark::AbstractCost< unsigned int, RealVector >
Features m_features
 

Detailed Description

Error measure for classification tasks that can be used as the objective function for training.

If your model returns a vector whose components reflect the logarithmic conditional probabilities of class membership given an input vector, CrossEntropy is the adequate error measure for model training. For C > 1 classes the loss function is defined as

\[ E = - \ln \frac{\exp(x_c)} {\sum_{c^{\prime}=1}^C \exp(x_{c^{\prime}})} = - x_c + \ln \sum_{c^{\prime}=1}^C \exp(x_{c^{\prime}}) \]

where x is the prediction vector of the model and c is the class label. In the case of a single model output and binary classification, a more numerically stable formulation is used:
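The multi-class formula above can be sketched as a short stand-alone computation. Note that `crossEntropy` here is a hypothetical helper written only to illustrate the math, not part of Shark's API; it uses the log-sum-exp trick (subtracting the maximum before exponentiating) that Shark's `eval()` also relies on, as seen in its reference to `shark::blas::max()` below.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative stand-alone evaluation of the multi-class cross-entropy
//   E = -x_c + ln( sum_{c'} exp(x_{c'}) )
// for a prediction vector x and a class label c. Hypothetical helper,
// not part of the Shark API.
double crossEntropy(std::vector<double> const& x, std::size_t c) {
    // Log-sum-exp trick: subtract the maximum component before
    // exponentiating so that the largest exponent is exp(0) = 1.
    double m = x[0];
    for (double v : x) if (v > m) m = v;
    double sum = 0.0;
    for (double v : x) sum += std::exp(v - m);
    // ln sum_{c'} exp(x_{c'}) = m + ln sum_{c'} exp(x_{c'} - m)
    return -x[c] + m + std::log(sum);
}
```

For a uniform prediction over three classes, e.g. x = (1, 1, 1), the loss reduces to ln 3 regardless of the label, matching the formula term by term.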

\[ E = \ln(1+ e^{-yx}) \]

Here y is a class label in {-1, 1} with y = -2 c + 1. This formulation is numerically more stable because when \( e^{-yx} \) is large, the loss is well approximated by the linear function \( -yx \), and when the exponential is very small, the case \( \ln(0) \) is avoided.
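The stable evaluation described above can be sketched as follows; `binaryCrossEntropy` is a hypothetical helper illustrating the documented formula, not Shark's implementation.

```cpp
#include <cassert>
#include <cmath>

// Numerically stable evaluation of E = ln(1 + exp(-y*x)) with y in {-1, +1}.
// Hypothetical helper illustrating the formula above, not part of the Shark API.
double binaryCrossEntropy(double x, int y) {
    double z = -static_cast<double>(y) * x;
    if (z > 0.0) {
        // ln(1 + e^z) = z + ln(1 + e^{-z}): the dominant term is the linear
        // function z = -y*x, and exp(-z) cannot overflow.
        return z + std::log1p(std::exp(-z));
    }
    // For z <= 0, log1p avoids evaluating ln of a value near 0.
    return std::log1p(std::exp(z));
}
```

At x = 0 both branches give ln 2, and for very large |yx| the loss grows linearly instead of overflowing.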

The class labels must be integers starting from 0. Also, for theoretical reasons, the output neurons of a neural network must be linear.

Definition at line 66 of file CrossEntropy.h.

Constructor & Destructor Documentation

§ CrossEntropy()

Member Function Documentation

§ eval()

double shark::CrossEntropy::eval ( UIntVector const &  target,
RealMatrix const &  prediction 
) const
inline

Definition at line 98 of file CrossEntropy.h.

References shark::blas::max(), RANGE_CHECK, and shark::blas::row().

§ evalDerivative() [1/2]

double shark::CrossEntropy::evalDerivative ( UIntVector const &  target,
RealMatrix const &  prediction,
RealMatrix &  gradient 
) const
inline

§ evalDerivative() [2/2]

double shark::CrossEntropy::evalDerivative ( unsigned int const &  target,
RealVector const &  prediction,
RealVector &  gradient,
RealMatrix &  hessian 
) const
inline virtual

evaluate the loss and its first and second derivative for a target and a prediction

Parameters
target: target value
prediction: prediction, typically made by a model
gradient: the gradient of the loss function with respect to the prediction
hessian: the Hessian of the loss function with respect to the prediction

Reimplemented from shark::AbstractLoss< unsigned int, RealVector >.

Definition at line 172 of file CrossEntropy.h.

References shark::blas::diag(), shark::blas::max(), shark::blas::noalias(), shark::blas::outer_prod(), RANGE_CHECK, shark::sigmoid(), and shark::blas::sum().
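For the multi-class loss defined in the description above, the gradient with respect to the prediction has the well-known closed form dE/dx_j = softmax(x)_j - [j == c]. The sketch below computes it with a hypothetical helper (not Shark's `evalDerivative`), which makes the finite-difference behavior of the documented loss easy to check.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Gradient of E = -x_c + ln( sum_{c'} exp(x_{c'}) ) w.r.t. the prediction x:
//   dE/dx_j = softmax(x)_j - [j == c].
// Hypothetical helper illustrating the math, not part of the Shark API.
std::vector<double> crossEntropyGradient(std::vector<double> const& x,
                                         std::size_t c) {
    // Stabilized softmax: shift by the maximum component first.
    double m = x[0];
    for (double v : x) if (v > m) m = v;
    std::vector<double> g(x.size());
    double sum = 0.0;
    for (std::size_t j = 0; j < x.size(); ++j) {
        g[j] = std::exp(x[j] - m);
        sum += g[j];
    }
    for (std::size_t j = 0; j < x.size(); ++j) g[j] /= sum;
    g[c] -= 1.0;  // subtract the one-hot target
    return g;
}
```

Because the softmax components sum to 1, the gradient components always sum to 0; for x = (1, 1, 1) and c = 0 the gradient is (1/3 - 1, 1/3, 1/3).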

§ name()

std::string shark::CrossEntropy::name ( ) const
inline virtual

From INameable: return the class name.

Reimplemented from shark::INameable.

Definition at line 92 of file CrossEntropy.h.

References shark::AbstractLoss< unsigned int, RealVector >::eval().


The documentation for this class was generated from the following file: