32 #ifndef SHARK_ML_SVMLOGISTICINTERPRETATION_H 33 #define SHARK_ML_SVMLOGISTICINTERPRETATION_H 41 #include <boost/math/special_functions/log1p.hpp> 62 template<
class InputType = RealVector>
// Fragment of the SvmLogisticInterpretation constructor (name taken from the
// check messages below): tail of the parameter list, part of the
// member-initializer list, and the first sanity checks.
// NOTE(review): the constructor's name line, some parameters/initializers and
// the rest of the body are not visible in this excerpt.
87 FoldsType
const &folds, KernelType *kernel,
// Number of hyperparameters: one more than the kernel's parameter count.
// NOTE(review): the extra slot is presumably the SVM regularization
// parameter C -- confirm against the full file.
91 , m_nhp(kernel->parameterVector().
size()+1)
// Number of kernel parameters.
92 , m_nkp(kernel->parameterVector().
size())
// Fold count and total element count are both read from the supplied folds object.
93 , m_numFolds(folds.
size())
94 , m_numSamples(folds.dataset().numberOfElements())
96 , m_svmCIsUnconstrained(unconstrained)
97 , mep_svmStoppingCondition(stop_cond)
// The sigmoid-slope flag is always initialised to true here.
98 , m_sigmoidSlopeIsUnconstrained(true)
// NOTE(review): kernel has already been dereferenced in the initializer list
// above (kernel->parameterVector()), so a null kernel would be dereferenced
// before this check is reached.
100 SHARK_CHECK(kernel != NULL,
"[SvmLogisticInterpretation::SvmLogisticInterpretation] kernel is not allowed to be NULL");
// Require more than one fold.
101 SHARK_CHECK(m_numFolds > 1,
"[SvmLogisticInterpretation::SvmLogisticInterpretation] please provide a meaningful number of folds for cross validation");
// NOTE(review): the statement guarded by this condition is not visible in this excerpt.
102 if (!m_svmCIsUnconstrained)
// Function body that returns the fixed string "SvmLogisticInterpretation".
// NOTE(review): the function's signature line is not visible in this excerpt;
// presumably this is the objective function's name accessor -- confirm.
112 {
return "SvmLogisticInterpretation"; }
// Fragment of SvmLogisticInterpretation::eval (name taken from the check
// message below). Visible steps: validate the parameter count, store
// validation labels and SVM scores in helper buffers, train a sigmoid model,
// and accumulate log-likelihood terms into `error`.
// NOTE(review): the signature, the enclosing loops, the `else` line and the
// return statement are missing from this excerpt.
// The parameter vector must have exactly m_nhp entries.
136 SHARK_CHECK(m_nhp == parameters.size(),
"[SvmLogisticInterpretation::eval] wrong number of parameters");
// Helper buffers with one slot per sample; next_label is the write index used
// for both buffers below.
142 std::vector< unsigned int > tmp_helper_labels(m_numSamples);
143 std::vector< RealVector > tmp_helper_preds(m_numSamples);
144 unsigned int next_label = 0;
// Branch taken only when a stopping condition pointer was supplied.
// NOTE(review): the body of this branch is not visible here.
158 if (mep_svmStoppingCondition != NULL) {
// Train the SVM on the current training data.
166 csvm_trainer.
train(svm, cur_train_data);
// Store the j-th validation label and score at position next_label.
// NOTE(review): the increment of next_label is not visible in this excerpt.
169 for (std::size_t j=0; j<cur_vsize; j++) {
170 tmp_helper_labels[next_label] = cur_vlabels.
element(j);
171 tmp_helper_preds[next_label] = cur_vscores.
element(j);
// Train a sigmoid model on validation_dataset.
// NOTE(review): validation_dataset is presumably built from the helper
// buffers above -- its construction is not visible here.
180 SigmoidModel sigmoid_model(m_sigmoidSlopeIsUnconstrained);
183 sigmoid_trainer.
train(sigmoid_model, validation_dataset);
// p is component 0 of the i-th sigmoid prediction.
189 double p = sigmoid_predictions.
element(i)(0);
// A label equal to 1 contributes -log(p) to error.
190 if (all_validation_labels.
element(i) == 1){
191 error -= std::log(p);
// Contributes -log(1-p); log1p(-p) evaluates log(1-p).
// NOTE(review): presumably the else branch for labels other than 1 -- the
// `else` line is not visible in this excerpt.
194 error -= boost::math::log1p(-p);
// Fragment of SvmLogisticInterpretation::evalDerivative (name taken from the
// check message below). Visible steps mirror eval: validate the parameter
// count, store validation labels/scores plus one derivative row per sample,
// train a sigmoid model, then accumulate both the error and its derivative
// with respect to each of the m_nhp hyperparameters.
// NOTE(review): the signature, the enclosing loops, the `else` line, the
// definition of `ss` and the end of the function are missing from this excerpt.
// The parameter vector must have exactly m_nhp entries.
206 SHARK_CHECK(m_nhp == parameters.size(),
"[SvmLogisticInterpretation::evalDerivative] wrong number of parameters");
// Helper buffers with one slot per sample; next_label is the shared write index.
212 std::vector< unsigned int > tmp_helper_labels(m_numSamples);
213 std::vector< RealVector > tmp_helper_preds(m_numSamples);
215 unsigned int next_label = 0;
// One row per sample, one column per hyperparameter.
217 RealMatrix all_validation_predict_derivs(m_numSamples, m_nhp);
// Branch taken only when a stopping condition pointer was supplied.
// NOTE(review): the body of this branch is not visible here.
233 if (mep_svmStoppingCondition != NULL) {
// Train the SVM on the current training data.
239 csvm_trainer.
train(svm, cur_train_data);
// Store the j-th validation label and score at position next_label.
243 for (std::size_t j=0; j<cur_vsize; j++) {
245 tmp_helper_labels[next_label] = cur_vlabels.
element(j);
246 tmp_helper_preds[next_label] = cur_vscores.
element(j);
// Store the vector `der` as row next_label of the derivative matrix.
// NOTE(review): the computation of `der` is not visible in this excerpt.
249 noalias(
row(all_validation_predict_derivs, next_label)) = der;
// Train a sigmoid model on validation_dataset.
258 SigmoidModel sigmoid_model(m_sigmoidSlopeIsUnconstrained);
261 sigmoid_trainer.
train(sigmoid_model, validation_dataset);
// Size the output derivative to one entry per hyperparameter.
// NOTE(review): it is accumulated with += below, so it presumably gets
// zeroed on a line not visible here -- confirm.
269 derivative.resize(m_nhp);
// p is component 0 of the i-th sigmoid prediction.
275 double p = sigmoid_predictions.
element(i)(0);
// A label equal to 1 contributes -log(p) to error.
// NOTE(review): the matching dL_dsp assignment for this branch is not visible.
278 if (all_validation_labels.
element(i) == 1){
279 error -= std::log(p);
// Contributes -log(1-p) to error; its derivative with respect to p is 1/(1-p).
// NOTE(review): presumably the else branch -- the `else` line is not visible.
283 error -= boost::math::log1p(-p);
284 dL_dsp = 1.0/(1.0-p);
// NOTE(review): `ss` is defined outside this excerpt; the expression has the
// form slope * p * (1-p), i.e. the derivative of a logistic sigmoid output
// with respect to its input -- confirm that ss is the sigmoid slope.
288 double dsp_dsvmp = ss * p * (1.0-p);
// Chain rule: add sample i's contribution to every hyperparameter derivative.
289 for (std::size_t j=0; j<
m_nhp; j++) {
290 derivative(j) += dL_dsp * dsp_dsvmp * all_validation_predict_derivs(i,j);