SHOGUN  v3.2.0
ScatterSVM.cpp
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2009 Soeren Sonnenburg
 * Written (W) 2009 Marius Kloft
 * Copyright (C) 2009 TU Berlin and Max-Planck-Society
 */

#include <shogun/multiclass/ScatterSVM.h>
#include <shogun/kernel/normalizer/ScatterKernelNormalizer.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/lib/SGVector.h>
#include <shogun/mathematics/Math.h>
#include <shogun/io/SGIO.h>

using namespace shogun;

CScatterSVM::CScatterSVM()
: CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(NO_BIAS_LIBSVM),
    model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
    SG_UNSTABLE("CScatterSVM::CScatterSVM()", "\n")
}

CScatterSVM::CScatterSVM(SCATTER_TYPE type)
: CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(type), model(NULL),
    norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
}

CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab)
: CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL),
    norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
}

CScatterSVM::~CScatterSVM()
{
    SG_FREE(norm_wc);
    SG_FREE(norm_wcw);
}

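/* Train the scatter machine: requires multiclass labels, counts the examples
 * per class, and dispatches either to train_no_bias_libsvm() or, after checking
 * that nu lies in the valid interval, to train_testrule12(). */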
bool CScatterSVM::train_machine(CFeatures* data)
{
    ASSERT(m_labels && m_labels->get_num_labels())
    ASSERT(m_labels->get_label_type() == LT_MULTICLASS)

    m_num_classes = m_multiclass_strategy->get_num_classes();
    int32_t num_vectors = m_labels->get_num_labels();

    if (data)
    {
        if (m_labels->get_num_labels() != data->get_num_vectors())
            SG_ERROR("Number of training vectors does not match number of labels\n")
        m_kernel->init(data, data);
    }

    // count examples per class and find the size of the smallest non-empty class
    int32_t* numc=SG_MALLOC(int32_t, m_num_classes);
    SGVector<int32_t>::fill_vector(numc, m_num_classes, 0);

    for (int32_t i=0; i<num_vectors; i++)
        numc[(int32_t) ((CMulticlassLabels*) m_labels)->get_int_label(i)]++;

    int32_t Nc=0;
    int32_t Nmin=num_vectors;
    for (int32_t i=0; i<m_num_classes; i++)
    {
        if (numc[i]>0)
        {
            Nc++;
            Nmin=CMath::min(Nmin, numc[i]);
        }
    }
    SG_FREE(numc);
    m_num_classes=Nc;

    bool result=false;

    if (scatter_type==NO_BIAS_LIBSVM)
    {
        result=train_no_bias_libsvm();
    }
    else if (scatter_type==TEST_RULE1 || scatter_type==TEST_RULE2)
    {
        // valid nu range for the scatter machine: [Nc/N, Nc*Nmin/N]
        float64_t nu_min=((float64_t) Nc)/num_vectors;
        float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors;

        SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max)

        if (get_nu()<nu_min || get_nu()>nu_max)
            SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max)

        result=train_testrule12();
    }
    else
        SG_ERROR("Unknown Scatter type\n")

    return result;
}

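/* No-bias variant: every target is set to +1 and the class structure enters
 * through a CScatterKernelNormalizer; the output of Shogun's modified LibSVM
 * solver is then unpacked into one CSVM per class. */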
bool CScatterSVM::train_no_bias_libsvm()
{
    struct svm_node* x_space;

    problem.l=m_labels->get_num_labels();
    SG_INFO("%d trainlabels\n", problem.l)

    problem.y=SG_MALLOC(float64_t, problem.l);
    problem.x=SG_MALLOC(struct svm_node*, problem.l);
    x_space=SG_MALLOC(struct svm_node, 2*problem.l);

    // one sparse node per example: only the index matters, the kernel supplies the values
    for (int32_t i=0; i<problem.l; i++)
    {
        problem.y[i]=+1;
        problem.x[i]=&x_space[2*i];
        x_space[2*i].index=i;
        x_space[2*i+1].index=-1;
    }

    int32_t weights_label[2]={-1,+1};
    float64_t weights[2]={1.0,get_C()/get_C()}; // identical weights for both label values (C/C == 1)

    ASSERT(m_kernel && m_kernel->has_features())
    ASSERT(m_kernel->get_num_vec_lhs()==problem.l)

    param.svm_type=C_SVC; // class structure is carried by the scatter kernel normalizer
    param.kernel_type = LINEAR;
    param.degree = 3;
    param.gamma = 0; // 1/k
    param.coef0 = 0;
    param.nu = get_nu();
    CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer();
    m_kernel->set_normalizer(new CScatterKernelNormalizer(
            m_num_classes-1, -1, m_labels, prev_normalizer));
    param.kernel=m_kernel;
    param.cache_size = m_kernel->get_cache_size();
    param.C = 0;
    param.eps = get_epsilon();
    param.p = 0.1;
    param.shrinking = 0;
    param.nr_weight = 2;
    param.weight_label = weights_label;
    param.weight = weights;
    param.nr_class=m_num_classes;
    param.use_bias = svm_proto()->get_bias_enabled();

    const char* error_msg = svm_check_parameter(&problem, &param);

    if (error_msg)
        SG_ERROR("Error: %s\n", error_msg)

    model = svm_train(&problem, &param);
    m_kernel->set_normalizer(prev_normalizer);
    SG_UNREF(prev_normalizer);

    if (model)
    {
        ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef))

        ASSERT(model->nr_class==m_num_classes)
        create_multiclass_svm(m_num_classes);

        rho=model->rho[0];

        SG_FREE(norm_wcw);
        norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements());

        // unpack the solver output into one CSVM per class
        for (int32_t i=0; i<m_num_classes; i++)
        {
            int32_t num_sv=model->nSV[i];

            CSVM* svm=new CSVM(num_sv);
            svm->set_bias(model->rho[i+1]);
            norm_wcw[i]=model->normwcw[i];

            for (int32_t j=0; j<num_sv; j++)
            {
                svm->set_alpha(j, model->sv_coef[i][j]);
                svm->set_support_vector(j, model->SV[i][j].index);
            }

            set_svm(i, svm);
        }

        SG_FREE(problem.x);
        SG_FREE(problem.y);
        SG_FREE(x_space);
        for (int32_t i=0; i<m_num_classes; i++)
        {
            SG_FREE(model->SV[i]);
            model->SV[i]=NULL;
        }
        svm_destroy_model(model);

        compute_norm_wc();

        model=NULL;
        return true;
    }
    else
        return false;
}

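/* Biased variants (test rules 1 and 2): the original multiclass labels are
 * handed to the NU_MULTICLASS_SVC solver and the per-class biases are read
 * back from model->rho[1..nr_class]. */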
bool CScatterSVM::train_testrule12()
{
    struct svm_node* x_space;
    problem.l=m_labels->get_num_labels();
    SG_INFO("%d trainlabels\n", problem.l)

    problem.y=SG_MALLOC(float64_t, problem.l);
    problem.x=SG_MALLOC(struct svm_node*, problem.l);
    x_space=SG_MALLOC(struct svm_node, 2*problem.l);

    // here the true multiclass labels are passed to the solver
    for (int32_t i=0; i<problem.l; i++)
    {
        problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i);
        problem.x[i]=&x_space[2*i];
        x_space[2*i].index=i;
        x_space[2*i+1].index=-1;
    }

    int32_t weights_label[2]={-1,+1};
    float64_t weights[2]={1.0,get_C()/get_C()}; // identical weights for both label values (C/C == 1)

    ASSERT(m_kernel && m_kernel->has_features())
    ASSERT(m_kernel->get_num_vec_lhs()==problem.l)

    param.svm_type=NU_MULTICLASS_SVC; // Nu multiclass SVM
    param.kernel_type = LINEAR;
    param.degree = 3;
    param.gamma = 0; // 1/k
    param.coef0 = 0;
    param.nu = get_nu();
    param.kernel=m_kernel;
    param.cache_size = m_kernel->get_cache_size();
    param.C = 0;
    param.eps = get_epsilon();
    param.p = 0.1;
    param.shrinking = 0;
    param.nr_weight = 2;
    param.weight_label = weights_label;
    param.weight = weights;
    param.nr_class=m_num_classes;
    param.use_bias = svm_proto()->get_bias_enabled();

    const char* error_msg = svm_check_parameter(&problem, &param);

    if (error_msg)
        SG_ERROR("Error: %s\n", error_msg)

    model = svm_train(&problem, &param);

    if (model)
    {
        ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef))

        ASSERT(model->nr_class==m_num_classes)
        create_multiclass_svm(m_num_classes);

        rho=model->rho[0];

        SG_FREE(norm_wcw);
        norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements());

        // unpack the solver output into one CSVM per class
        for (int32_t i=0; i<m_num_classes; i++)
        {
            int32_t num_sv=model->nSV[i];

            CSVM* svm=new CSVM(num_sv);
            svm->set_bias(model->rho[i+1]);
            norm_wcw[i]=model->normwcw[i];

            for (int32_t j=0; j<num_sv; j++)
            {
                svm->set_alpha(j, model->sv_coef[i][j]);
                svm->set_support_vector(j, model->SV[i][j].index);
            }

            set_svm(i, svm);
        }

        SG_FREE(problem.x);
        SG_FREE(problem.y);
        SG_FREE(x_space);
        for (int32_t i=0; i<m_num_classes; i++)
        {
            SG_FREE(model->SV[i]);
            model->SV[i]=NULL;
        }
        svm_destroy_model(model);

        compute_norm_wc();

        model=NULL;
        return true;
    }
    else
        return false;
}

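/* Compute ||w_c|| = sqrt(sum_{i,j} alpha_i alpha_j k(x_i,x_j)) over each
 * per-class machine's support vectors; the norms scale the per-class outputs
 * at prediction time. */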
void CScatterSVM::compute_norm_wc()
{
    SG_FREE(norm_wc);
    norm_wc = SG_MALLOC(float64_t, m_machines->get_num_elements());
    for (int32_t i=0; i<m_machines->get_num_elements(); i++)
        norm_wc[i]=0;

    for (int32_t c=0; c<m_machines->get_num_elements(); c++)
    {
        CSVM* svm=get_svm(c);
        int32_t num_sv = svm->get_num_support_vectors();

        for (int32_t i=0; i<num_sv; i++)
        {
            int32_t ii=svm->get_support_vector(i);
            for (int32_t j=0; j<num_sv; j++)
            {
                int32_t jj=svm->get_support_vector(j);
                norm_wc[c]+=svm->get_alpha(i)*m_kernel->kernel(ii,jj)*svm->get_alpha(j);
            }
        }
    }

    for (int32_t i=0; i<m_machines->get_num_elements(); i++)
        norm_wc[i]=CMath::sqrt(norm_wc[i]);
}

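/* Predict labels for all vectors on the kernel's rhs: TEST_RULE1 delegates to
 * apply_one(), otherwise each per-class SVM is applied and the class with the
 * largest ||w_c||-normalized output wins. */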
CLabels* CScatterSVM::classify_one_vs_rest()
{
    CMulticlassLabels* output=NULL;
    if (!m_kernel)
    {
        SG_ERROR("SVM can not proceed without kernel!\n")
        return NULL;
    }

    if (!m_kernel->get_num_vec_lhs() || !m_kernel->get_num_vec_rhs())
        return NULL;

    int32_t num_vectors=m_kernel->get_num_vec_rhs();

    output=new CMulticlassLabels(num_vectors);
    SG_REF(output);

    if (scatter_type == TEST_RULE1)
    {
        ASSERT(num_vectors==output->get_num_labels())
        for (int32_t i=0; i<num_vectors; i++)
            output->set_label(i, apply_one(i));
    }
    else
    {
        // TEST_RULE2 / no-bias: winner is argmax_c of f_c(x)/||w_c||
        ASSERT(num_vectors==output->get_num_labels())
        CLabels** outputs=SG_MALLOC(CLabels*, m_machines->get_num_elements());

        for (int32_t i=0; i<m_machines->get_num_elements(); i++)
        {
            //SG_PRINT("svm %d\n", i)
            CSVM* svm = get_svm(i);
            ASSERT(svm)
            svm->set_kernel(m_kernel);
            svm->set_labels(m_labels);
            outputs[i]=svm->apply();
            SG_UNREF(svm);
        }

        for (int32_t i=0; i<num_vectors; i++)
        {
            int32_t winner=0;
            float64_t max_out=((CRegressionLabels*) outputs[0])->get_label(i)/norm_wc[0];

            for (int32_t j=1; j<m_machines->get_num_elements(); j++)
            {
                float64_t out=((CRegressionLabels*) outputs[j])->get_label(i)/norm_wc[j];

                if (out>max_out)
                {
                    winner=j;
                    max_out=out;
                }
            }

            output->set_label(i, winner);
        }

        for (int32_t i=0; i<m_machines->get_num_elements(); i++)
            SG_UNREF(outputs[i]);

        SG_FREE(outputs);
    }

    return output;
}

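/* Predict the class of a single example: TEST_RULE1 combines per-class biases
 * and kernel expansions, centers the scores and scales them by norm_wcw;
 * otherwise the ||w_c||-normalized one-vs-rest outputs are compared directly.
 * Returns the index of the winning class. */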
float64_t CScatterSVM::apply_one(int32_t num)
{
    float64_t* outputs=SG_MALLOC(float64_t, m_machines->get_num_elements());
    int32_t winner=0;

    if (scatter_type == TEST_RULE1)
    {
        for (int32_t c=0; c<m_machines->get_num_elements(); c++)
            outputs[c]=get_svm(c)->get_bias()-rho;

        for (int32_t c=0; c<m_machines->get_num_elements(); c++)
        {
            float64_t v=0;

            for (int32_t i=0; i<get_svm(c)->get_num_support_vectors(); i++)
            {
                float64_t alpha=get_svm(c)->get_alpha(i);
                int32_t svidx=get_svm(c)->get_support_vector(i);
                v += alpha*m_kernel->kernel(svidx, num);
            }

            // add the kernel expansion for class c and center across all classes
            outputs[c] += v;
            for (int32_t j=0; j<m_machines->get_num_elements(); j++)
                outputs[j] -= v/m_machines->get_num_elements();
        }

        for (int32_t j=0; j<m_machines->get_num_elements(); j++)
            outputs[j]/=norm_wcw[j];

        float64_t max_out=outputs[0];
        for (int32_t j=0; j<m_machines->get_num_elements(); j++)
        {
            if (outputs[j]>max_out)
            {
                max_out=outputs[j];
                winner=j;
            }
        }
    }
    else
    {
        // TEST_RULE2 / no-bias: plain argmax of the normalized one-vs-rest outputs
        float64_t max_out=get_svm(0)->apply_one(num)/norm_wc[0];

        for (int32_t i=1; i<m_machines->get_num_elements(); i++)
        {
            outputs[i]=get_svm(i)->apply_one(num)/norm_wc[i];
            if (outputs[i]>max_out)
            {
                winner=i;
                max_out=outputs[i];
            }
        }
    }

    SG_FREE(outputs);
    return winner;
}
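
/* A minimal usage sketch of CScatterSVM, assuming dense real-valued features
 * "feats" and multiclass labels "labels" (classes 0..K-1) already exist; the
 * Gaussian kernel, its width 2.0 and C=1.0 are arbitrary example choices, and
 * error handling is omitted. */

#include <shogun/features/DenseFeatures.h>
#include <shogun/kernel/GaussianKernel.h>

static shogun::CMulticlassLabels* scatter_svm_example(
    shogun::CDenseFeatures<float64_t>* feats, shogun::CMulticlassLabels* labels)
{
    using namespace shogun;

    // kernel on the training data; the width is only an example value
    CGaussianKernel* kernel=new CGaussianKernel(feats, feats, 2.0);

    // the (C, kernel, labels) constructor uses the NO_BIAS_LIBSVM scatter type
    CScatterSVM* svm=new CScatterSVM(1.0, kernel, labels);
    svm->train();

    // predict on some features (here: the training features themselves)
    CMulticlassLabels* out=(CMulticlassLabels*) svm->apply(feats);

    SG_UNREF(svm);
    return out; // caller owns the returned labels
}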