wagon.cc
/*************************************************************************/
/*                                                                       */
/*                Centre for Speech Technology Research                  */
/*                     University of Edinburgh, UK                       */
/*                       Copyright (c) 1996,1997                         */
/*                        All Rights Reserved.                           */
/*                                                                       */
/*  Permission is hereby granted, free of charge, to use and distribute  */
/*  this software and its documentation without restriction, including   */
/*  without limitation the rights to use, copy, modify, merge, publish,  */
/*  distribute, sublicense, and/or sell copies of this work, and to      */
/*  permit persons to whom this work is furnished to do so, subject to   */
/*  the following conditions:                                            */
/*   1. The code must retain the above copyright notice, this list of    */
/*      conditions and the following disclaimer.                         */
/*   2. Any modifications must be clearly marked as such.                */
/*   3. Original authors' names are not deleted.                         */
/*   4. The authors' names are not used to endorse or promote products   */
/*      derived from this software without specific prior written        */
/*      permission.                                                      */
/*                                                                       */
/*  THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK        */
/*  DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING      */
/*  ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT   */
/*  SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE     */
/*  FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES    */
/*  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN   */
/*  AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,          */
/*  ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF       */
/*  THIS SOFTWARE.                                                       */
/*                                                                       */
/*************************************************************************/
/*                 Author :  Alan W Black                                */
/*                 Date   :  May 1996                                    */
/*-----------------------------------------------------------------------*/
/*  A Classification and Regression Tree (CART) Program                  */
/*  A basic implementation of many of the techniques in                  */
/*  Breiman et al. 1984                                                  */
/*                                                                       */
/*  Added decision list support, Feb 1997                                */
/*  Added stepwise use of features, Oct 1997                             */
/*                                                                       */
/*=======================================================================*/

#include <cstdlib>
#include <iostream>
#include <fstream>
#include <cstring>
#include "EST_Token.h"
#include "EST_FMatrix.h"
#include "EST_multistats.h"
#include "EST_Wagon.h"
#include "EST_math.h"

Discretes wgn_discretes;

WDataSet wgn_dataset;
WDataSet wgn_test_dataset;
EST_FMatrix wgn_DistMatrix;
EST_Track wgn_VertexTrack;
EST_Track wgn_VertexFeats;
EST_Track wgn_UnitTrack;

int wgn_min_cluster_size = 50;
int wgn_held_out = 0;
int wgn_prune = TRUE;
int wgn_quiet = FALSE;
int wgn_verbose = FALSE;
int wgn_count_field = -1;
EST_String wgn_count_field_name = "";
int wgn_predictee = 0;
EST_String wgn_predictee_name = "";
float wgn_float_range_split = 10;
float wgn_balance = 0;
EST_String wgn_opt_param = "";
EST_String wgn_vertex_output = "mean";
EST_String wgn_vertex_otype = "mean";

static float do_summary(WNode &tree,WDataSet &ds,ostream *output);
static float test_tree_float(WNode &tree,WDataSet &ds,ostream *output);
static float test_tree_class(WNode &tree,WDataSet &ds,ostream *output);
static float test_tree_cluster(WNode &tree,WDataSet &dataset,ostream *output);
static float test_tree_vector(WNode &tree,WDataSet &dataset,ostream *output);
static float test_tree_trajectory(WNode &tree,WDataSet &dataset,ostream *output);
static int wagon_split(int margin,WNode &node);
static WQuestion find_best_question(WVectorVector &dset);
static void construct_binary_ques(int feat,WQuestion &test_ques);
static float construct_float_ques(int feat,WQuestion &ques,WVectorVector &ds);
static float construct_class_ques(int feat,WQuestion &ques,WVectorVector &ds);
static void wgn_set_up_data(WVectorVector &data,const WVectorList &ds,int held_out,int in);
static WNode *wagon_stepwise_find_next_best(float &bscore,int &best_feat);

Declare_TList_T(WVector *, WVectorP)

Declare_TVector_Base_T(WVector *,NULL,NULL,WVectorP)

#if defined(INSTANTIATE_TEMPLATES)
// Instantiate class
#include "../base_class/EST_TList.cc"
#include "../base_class/EST_TVector.cc"

Instantiate_TList_T(WVector *, WVectorP)

Instantiate_TVector(WVector *)

#endif

void wgn_load_datadescription(EST_String fname,LISP ignores)
{
    // Load field description for a file
    wgn_dataset.load_description(fname,ignores);
    wgn_test_dataset.load_description(fname,ignores);
}
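
/* Editorial illustration (not part of the original source): the
   description file read above is a Lisp-style list with one (name type)
   entry per field of the data file.  A plausible example, with
   hypothetical feature names:

     ((duration float)
      (stressed 0 1)
      (ph_vc + -)
      (speaker_id ignore))

   where "float" marks a numeric field, an explicit list of values
   defines a class (or binary) field, and "ignore" makes the loader
   skip that column when building questions. */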

void wgn_load_dataset(WDataSet &dataset,EST_String fname)
{
    // Read the data set from a filename.  One vector per line.
    // Assume all numbers are numbers and non-nums are categorical
    EST_TokenStream ts;
    WVector *v;
    int nvec=0,i;

    if (ts.open(fname) == -1)
        wagon_error(EST_String("unable to open data file \"")+
                    fname+"\"");
    ts.set_PunctuationSymbols("");
    ts.set_PrePunctuationSymbols("");
    ts.set_SingleCharSymbols("");

    for ( ; !ts.eof(); )
    {
        v = new WVector(dataset.width());
        i = 0;
        do
        {
            int type = dataset.ftype(i);
            if ((type == wndt_float) || (wgn_count_field == i))
            {
                // need to ensure this is not NaN or Infinity
                float f = atof(ts.get().string());
                if (finite(f))
                    v->set_flt_val(i,f);
                else
                {
                    cout << fname << ": bad float " << f
                         << " in field " <<
                        dataset.feat_name(i) << " vector " <<
                        dataset.samples() << endl;
                    v->set_flt_val(i,0.0);
                }
            }
            else if (type == wndt_binary)
                v->set_int_val(i,atoi(ts.get().string()));
            else if (type == wndt_cluster)  /* index into distmatrix */
                v->set_int_val(i,atoi(ts.get().string()));
            else if (type == wndt_vector)   /* index into VertexTrack */
                v->set_int_val(i,atoi(ts.get().string()));
            else if (type == wndt_trajectory) /* index to index and length */
            {   /* a number pointing to a vector in UnitTrack that      */
                /* has an index into VertexTrack and a number of        */
                /* vertices.  Thus if it's 15, UnitTrack.a(15,0) is the */
                /* start frame in VertexTrack and UnitTrack.a(15,1) is  */
                /* the number of frames in the unit                     */
                v->set_int_val(i,atoi(ts.get().string()));
            }
            else if (type == wndt_ignore)
            {
                ts.get();  // skip it
                v->set_int_val(i,0);
            }
            else // should check the different classes
            {
                EST_String s = ts.get().string();
                int n = wgn_discretes.discrete(type).name(s);
                if (n == -1)
                {
                    cout << fname << ": bad value " << s << " in field " <<
                        dataset.feat_name(i) << " vector " <<
                        dataset.samples() << endl;
                    n = 0;
                }
                v->set_int_val(i,n);
            }
            i++;
        }
        while (!ts.eoln() && i<dataset.width());
        nvec ++;
        if (i != dataset.width())
        {
            wagon_error(fname+": data vector "+itoString(nvec)+" contains "
                        +itoString(i)+" parameters instead of "+
                        itoString(dataset.width()));
        }
        if (!ts.eoln())
        {
            cerr << fname << ": data vector " << nvec <<
                " contains too many parameters instead of "
                 << dataset.width() << endl;
            wagon_error(EST_String("extra parameter(s) from ")+
                        ts.peek().string());
        }
        dataset.append(v);
    }

    cout << "Dataset of " << dataset.samples() << " vectors of " <<
        dataset.width() << " parameters from: " << fname << endl;
    ts.close();
}
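
/* Editorial illustration (hypothetical, matching the description sketch
   above): each line of the data file holds one vector, whitespace
   separated, fields in description order, e.g.

     0.153 1 + spk1
     0.062 0 - spk2

   Numeric fields are parsed with atof()/atoi(); anything else is looked
   up in the discrete value set declared for its field. */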

float summary_results(WNode &tree,ostream *output)
{
    if (wgn_test_dataset.samples() != 0)
        return do_summary(tree,wgn_test_dataset,output);
    else
        return do_summary(tree,wgn_dataset,output);
}

static float do_summary(WNode &tree,WDataSet &ds,ostream *output)
{
    if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
        return test_tree_cluster(tree,ds,output);
    else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
        return test_tree_vector(tree,ds,output);
    else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
        return test_tree_trajectory(tree,ds,output);
    else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
        return test_tree_class(tree,ds,output);
    else
        return test_tree_float(tree,ds,output);
}

WNode *wgn_build_tree(float &score)
{
    // Build init node and split it while reducing the impurity
    WNode *top = new WNode();
    int margin = 0;

    wgn_set_up_data(top->get_data(),wgn_dataset,wgn_held_out,TRUE);

    margin = 0;
    wagon_split(margin,*top);  // recursively split data

    if (wgn_held_out > 0)
    {
        wgn_set_up_data(top->get_data(),wgn_dataset,wgn_held_out,FALSE);
        top->held_out_prune();
    }

    if (wgn_prune)
        top->prune();

    score = summary_results(*top,0);

    return top;
}
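
/* Editorial note: the build order above is fixed -- grow the tree
   greedily (wagon_split), optionally prune against the held-out
   portion of the data, optionally apply the standard prune, and
   finally score the tree via summary_results(), which uses the test
   set if one was loaded and the training set otherwise. */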

static void wgn_set_up_data(WVectorVector &data,const WVectorList &ds,int held_out,int in)
{
    // Set data omitting held_out percent if in is true
    // or only including 100-held_out percent if in is false
    int i,j;
    EST_Litem *d;

    // Make it definitely big enough
    data.resize(ds.length());

    for (j=i=0,d=ds.head(); d != 0; d=d->next(),j++)
    {
        if ((in) && ((j%100) >= held_out))
            data[i++] = ds(d);
//      else if ((!in) && ((j%100 < held_out)))
//          data[i++] = ds(d);
        else if (!in)
            data[i++] = ds(d);
//      if ((in) && (j < held_out))
//          data[i++] = ds(d);
//      else if ((!in) && (j >= held_out))
//          data[i++] = ds(d);
    }
    // make it the actual size, but don't reset values
    data.resize(i,1);
}
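
/* Worked example (editorial note): with held_out=10 and in=TRUE the
   condition (j%100) >= held_out keeps vectors 10..99 of every block of
   100, i.e. the first 10% are withheld for pruning.  With in=FALSE the
   active code keeps every vector; the stricter complement (only the
   withheld 10%) is left commented out above. */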

static float test_tree_class(WNode &tree,WDataSet &dataset,ostream *output)
{
    // Test tree against data to get summary of results
    EST_StrStr_KVL pairs;
    EST_StrList lex;
    EST_Litem *p;
    EST_String predict,real;
    WNode *pnode;
    double H=0,prob;
    int i,type;
    float correct=0,total=0, count=0;

    for (p=dataset.head(); p != 0; p=p->next())
    {
        pnode = tree.predict_node((*dataset(p)));
        predict = (EST_String)pnode->get_impurity().value();
        if (wgn_count_field == -1)
            count = 1.0;
        else
            count = dataset(p)->get_flt_val(wgn_count_field);
        prob = pnode->get_impurity().pd().probability(predict);
        H += (log(prob))*count;
        type = dataset.ftype(wgn_predictee);
        real = wgn_discretes[type].name(dataset(p)->get_int_val(wgn_predictee));
        if (real == predict)
            correct += count;
        total += count;
        pairs.add_item(real,predict,1);
    }
    for (i=0; i<wgn_discretes[dataset.ftype(wgn_predictee)].length(); i++)
        lex.append(wgn_discretes[dataset.ftype(wgn_predictee)].name(i));

    const EST_FMatrix &m = confusion(pairs,lex);

    if (output != NULL)
    {
        print_confusion(m,pairs,lex);  // should be to output not stdout
        *output << ";; entropy " << (-1*(H/total)) << " perplexity " <<
            pow(2.0,(-1*(H/total))) << endl;
    }

    // Minus it so bigger is better
    if (wgn_opt_param == "entropy")
        return -pow(2.0,(-1*(H/total)));
    else
        return (float)correct/(float)total;
}
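
/* Editorial note: the summary above reports entropy as -H/total, where
   H accumulates the count-weighted log-probability of each predicted
   class, and perplexity as 2^entropy.  Since log() here is the natural
   log, H is in nats while the perplexity is taken base 2 -- kept as in
   the original code. */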

static float test_tree_vector(WNode &tree,WDataSet &dataset,ostream *output)
{
    // Test tree against data to get summary of results VECTOR
    // distance is calculated in zscores (as the values in the vector may
    // have quite different ranges)
    WNode *leaf;
    EST_Litem *p;
    float predict, actual;
    EST_SuffStats x,y,xx,yy,xy,se,e;
    EST_SuffStats b;
    int i,j,pos;
    double cor,error;
    double count;
    EST_Litem *pp;

    for (p=dataset.head(); p != 0; p=p->next())
    {
        leaf = tree.predict_node((*dataset(p)));
        pos = dataset(p)->get_int_val(wgn_predictee);
        for (j=0; j<wgn_VertexFeats.num_channels(); j++)
            if (wgn_VertexFeats.a(0,j) > 0.0)
            {
                b.reset();
                for (pp=leaf->get_impurity().members.head(); pp != 0; pp=pp->next())
                {
                    i = leaf->get_impurity().members.item(pp);
                    b += wgn_VertexTrack.a(i,j);
                }
                predict = b.mean();
                actual = wgn_VertexTrack.a(pos,j);
                if (wgn_count_field == -1)
                    count = 1.0;
                else
                    count = dataset(p)->get_flt_val(wgn_count_field);
                x.cumulate(predict,count);
                y.cumulate(actual,count);
                /* Normalize the error by the standard deviation */
                if (b.stddev() == 0)
                    error = predict-actual;
                else
                    error = (predict-actual)/b.stddev();
                error = predict-actual; /* awb_debug */
                se.cumulate((error*error),count);
                e.cumulate(fabs(error),count);
                xx.cumulate(predict*predict,count);
                yy.cumulate(actual*actual,count);
                xy.cumulate(predict*actual,count);
            }
    }

    // Pearson's product moment correlation coefficient
//  cor = (xy.mean() - (x.mean()*y.mean()))/
//        (sqrt(xx.mean()-(x.mean()*x.mean())) *
//         sqrt(yy.mean()-(y.mean()*y.mean())));
    // Because when the variation in X is very small we can
    // go negative, thus causing the sqrts to give an FPE
    double v1 = xx.mean()-(x.mean()*x.mean());
    double v2 = yy.mean()-(y.mean()*y.mean());

    double v3 = v1*v2;

    if (v3 <= 0)
        // happens when there's very little variation in x
        cor = 0;
    else
        cor = (xy.mean() - (x.mean()*y.mean()))/sqrt(v3);

    if (output != NULL)
    {
        if (output != &cout)  // save in output file
            *output
                << ";; RMSE " << ftoString(sqrt(se.mean()),4,1)
                << " Correlation is " << ftoString(cor,4,1)
                << " Mean (abs) Error " << ftoString(e.mean(),4,1)
                << " (" << ftoString(e.stddev(),4,1) << ")" << endl;

        cout << "RMSE " << ftoString(sqrt(se.mean()),4,1)
             << " Correlation is " << ftoString(cor,4,1)
             << " Mean (abs) Error " << ftoString(e.mean(),4,1)
             << " (" << ftoString(e.stddev(),4,1) << ")" << endl;
    }

    if (wgn_opt_param == "rmse")
        return -sqrt(se.mean());  // * -1 so bigger is better
    else
        return cor;  // should really be % variance, I think
}
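
/* Editorial note: the guarded computation above is Pearson's r from
   sufficient statistics,
     r = (E[xy] - E[x]E[y]) / sqrt((E[x^2]-E[x]^2) * (E[y^2]-E[y]^2)),
   with the v3 <= 0 test returning r = 0 rather than risking sqrt() of
   a (numerically) negative variance when the predictions barely vary. */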

static float test_tree_trajectory(WNode &tree,WDataSet &dataset,ostream *output)
{
    // Test tree against data to get summary of results TRAJECTORY
    // distance is calculated in zscores (as the values in the vector may
    // have quite different ranges)
    // NOT WRITTEN YET
    WNode *leaf;
    EST_Litem *p;
    float predict, actual;
    EST_SuffStats x,y,xx,yy,xy,se,e;
    EST_SuffStats b;
    int i,j,pos;
    double cor,error;
    double count;
    EST_Litem *pp;

    for (p=dataset.head(); p != 0; p=p->next())
    {
        leaf = tree.predict_node((*dataset(p)));
        pos = dataset(p)->get_int_val(wgn_predictee);
        for (j=0; j<wgn_VertexFeats.num_channels(); j++)
            if (wgn_VertexFeats.a(0,j) > 0.0)
            {
                b.reset();
                for (pp=leaf->get_impurity().members.head(); pp != 0; pp=pp->next())
                {
                    i = leaf->get_impurity().members.item(pp);
                    b += wgn_VertexTrack.a(i,j);
                }
                predict = b.mean();
                actual = wgn_VertexTrack.a(pos,j);
                if (wgn_count_field == -1)
                    count = 1.0;
                else
                    count = dataset(p)->get_flt_val(wgn_count_field);
                x.cumulate(predict,count);
                y.cumulate(actual,count);
                /* Normalize the error by the standard deviation */
                if (b.stddev() == 0)
                    error = predict-actual;
                else
                    error = (predict-actual)/b.stddev();
                error = predict-actual; /* awb_debug */
                se.cumulate((error*error),count);
                e.cumulate(fabs(error),count);
                xx.cumulate(predict*predict,count);
                yy.cumulate(actual*actual,count);
                xy.cumulate(predict*actual,count);
            }
    }

    // Pearson's product moment correlation coefficient
//  cor = (xy.mean() - (x.mean()*y.mean()))/
//        (sqrt(xx.mean()-(x.mean()*x.mean())) *
//         sqrt(yy.mean()-(y.mean()*y.mean())));
    // Because when the variation in X is very small we can
    // go negative, thus causing the sqrts to give an FPE
    double v1 = xx.mean()-(x.mean()*x.mean());
    double v2 = yy.mean()-(y.mean()*y.mean());

    double v3 = v1*v2;

    if (v3 <= 0)
        // happens when there's very little variation in x
        cor = 0;
    else
        cor = (xy.mean() - (x.mean()*y.mean()))/sqrt(v3);

    if (output != NULL)
    {
        if (output != &cout)  // save in output file
            *output
                << ";; RMSE " << ftoString(sqrt(se.mean()),4,1)
                << " Correlation is " << ftoString(cor,4,1)
                << " Mean (abs) Error " << ftoString(e.mean(),4,1)
                << " (" << ftoString(e.stddev(),4,1) << ")" << endl;

        cout << "RMSE " << ftoString(sqrt(se.mean()),4,1)
             << " Correlation is " << ftoString(cor,4,1)
             << " Mean (abs) Error " << ftoString(e.mean(),4,1)
             << " (" << ftoString(e.stddev(),4,1) << ")" << endl;
    }

    if (wgn_opt_param == "rmse")
        return -sqrt(se.mean());  // * -1 so bigger is better
    else
        return cor;  // should really be % variance, I think
}

static float test_tree_cluster(WNode &tree,WDataSet &dataset,ostream *output)
{
    // Test tree against data to get summary of results for cluster trees
    WNode *leaf;
    int real;
    int right_cluster=0;
    EST_SuffStats ranking, meandist;
    EST_Litem *p;

    for (p=dataset.head(); p != 0; p=p->next())
    {
        leaf = tree.predict_node((*dataset(p)));
        real = dataset(p)->get_int_val(wgn_predictee);
        meandist += leaf->get_impurity().cluster_distance(real);
        right_cluster += leaf->get_impurity().in_cluster(real);
        ranking += leaf->get_impurity().cluster_ranking(real);
    }

    if (output != NULL)
    {
        // Want number in right class, mean distance in sds, mean ranking
        if (output != &cout)  // save in output file
            *output << ";; Right cluster " << right_cluster << " (" <<
                (int)(100.0*(float)right_cluster/(float)dataset.length()) <<
                "%) mean ranking " << ranking.mean() << " mean distance "
                    << meandist.mean() << endl;
        cout << "Right cluster " << right_cluster << " (" <<
            (int)(100.0*(float)right_cluster/(float)dataset.length()) <<
            "%) mean ranking " << ranking.mean() << " mean distance "
             << meandist.mean() << endl;
    }

    return 10000-meandist.mean();  // this doesn't work but I tested it
}

static float test_tree_float(WNode &tree,WDataSet &dataset,ostream *output)
{
    // Test tree against data to get summary of results FLOAT
    EST_Litem *p;
    float predict,real;
    EST_SuffStats x,y,xx,yy,xy,se,e;
    double cor,error;
    double count;

    for (p=dataset.head(); p != 0; p=p->next())
    {
        predict = tree.predict((*dataset(p)));
        real = dataset(p)->get_flt_val(wgn_predictee);
        if (wgn_count_field == -1)
            count = 1.0;
        else
            count = dataset(p)->get_flt_val(wgn_count_field);
        x.cumulate(predict,count);
        y.cumulate(real,count);
        error = predict-real;
        se.cumulate((error*error),count);
        e.cumulate(fabs(error),count);
        xx.cumulate(predict*predict,count);
        yy.cumulate(real*real,count);
        xy.cumulate(predict*real,count);
    }

    // Pearson's product moment correlation coefficient
//  cor = (xy.mean() - (x.mean()*y.mean()))/
//        (sqrt(xx.mean()-(x.mean()*x.mean())) *
//         sqrt(yy.mean()-(y.mean()*y.mean())));
    // Because when the variation in X is very small we can
    // go negative, thus causing the sqrts to give an FPE
    double v1 = xx.mean()-(x.mean()*x.mean());
    double v2 = yy.mean()-(y.mean()*y.mean());

    double v3 = v1*v2;

    if (v3 <= 0)
        // happens when there's very little variation in x
        cor = 0;
    else
        cor = (xy.mean() - (x.mean()*y.mean()))/sqrt(v3);

    if (output != NULL)
    {
        if (output != &cout)  // save in output file
            *output
                << ";; RMSE " << ftoString(sqrt(se.mean()),4,1)
                << " Correlation is " << ftoString(cor,4,1)
                << " Mean (abs) Error " << ftoString(e.mean(),4,1)
                << " (" << ftoString(e.stddev(),4,1) << ")" << endl;

        cout << "RMSE " << ftoString(sqrt(se.mean()),4,1)
             << " Correlation is " << ftoString(cor,4,1)
             << " Mean (abs) Error " << ftoString(e.mean(),4,1)
             << " (" << ftoString(e.stddev(),4,1) << ")" << endl;
    }

    if (wgn_opt_param == "rmse")
        return -sqrt(se.mean());  // * -1 so bigger is better
    else
        return cor;  // should really be % variance, I think
}

static int wagon_split(int margin, WNode &node)
{
    // Split given node (if possible)
    WQuestion q;
    WNode *l,*r;

    node.set_impurity(WImpurity(node.get_data()));
    q = find_best_question(node.get_data());

/*  printf("q.score() %f impurity %f\n",
           q.get_score(),
           node.get_impurity().measure()); */

    double impurity_measure = node.get_impurity().measure();
    double question_score = q.get_score();

    if ((question_score < WGN_HUGE_VAL) &&
        (question_score < impurity_measure))
    {
        // OK, it's worth a split
        l = new WNode();
        r = new WNode();
        wgn_find_split(q,node.get_data(),l->get_data(),r->get_data());
        node.set_subnodes(l,r);
        node.set_question(q);
        if (wgn_verbose)
        {
            int i;
            for (i=0; i < margin; i++)
                cout << " ";
            cout << q << endl;
        }
        margin++;
        wagon_split(margin,*l);
        margin++;
        wagon_split(margin,*r);
        margin--;
        return TRUE;
    }
    else
    {
        if (wgn_verbose)
        {
            int i;
            for (i=0; i < margin; i++)
                cout << " ";
            cout << "stopped samples: " << node.samples() << " impurity: "
                 << node.get_impurity() << endl;
        }
        margin--;
        return FALSE;
    }
}
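
/* Editorial note: a split is accepted only when the best question's
   score (the combined impurity of the two children) is strictly less
   than the parent's own impurity, so the recursion stops at nodes
   where no question helps -- or where the min-cluster-size/balance
   constraints make every candidate split invalid. */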

void wgn_find_split(WQuestion &q,WVectorVector &ds,
                    WVectorVector &y,WVectorVector &n)
{
    int i, iy, in;

    y.resize(q.get_yes());
    n.resize(q.get_no());

    for (iy=in=i=0; i < ds.n(); i++)
        if (q.ask(*ds(i)) == TRUE)
            y[iy++] = ds(i);
        else
            n[in++] = ds(i);

}

static WQuestion find_best_question(WVectorVector &dset)
{
    // Ask all possible questions and find the best one
    int i;
    float bscore,tscore;
    WQuestion test_ques, best_ques;

    bscore = tscore = WGN_HUGE_VAL;
    best_ques.set_score(bscore);
    // test each feature with each possible question
    for (i=0; i < wgn_dataset.width(); i++)
    {
        if ((wgn_dataset.ignore(i) == TRUE) ||
            (i == wgn_predictee))
            tscore = WGN_HUGE_VAL;  // ignore this feature this time
        else if (wgn_dataset.ftype(i) == wndt_binary)
        {
            construct_binary_ques(i,test_ques);
            tscore = wgn_score_question(test_ques,dset);
        }
        else if (wgn_dataset.ftype(i) == wndt_float)
        {
            tscore = construct_float_ques(i,test_ques,dset);
        }
        else if (wgn_dataset.ftype(i) == wndt_ignore)
            tscore = WGN_HUGE_VAL;  // always ignore this feature
#if 0
        // This doesn't work reasonably
        else if (wgn_csubset && (wgn_dataset.ftype(i) >= wndt_class))
        {
            wagon_error("subset selection temporarily deleted");
            tscore = construct_class_ques_subset(i,test_ques,dset);
        }
#endif
        else if (wgn_dataset.ftype(i) >= wndt_class)
            tscore = construct_class_ques(i,test_ques,dset);
        if (tscore < bscore)
        {
            best_ques = test_ques;
            best_ques.set_score(tscore);
            bscore = tscore;
        }
    }

    return best_ques;
}
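
/* Editorial note: this is an exhaustive greedy search -- every
   non-ignored feature is tried with every candidate question (one per
   class value, or wgn_float_range_split thresholds for floats), and
   each candidate is scored over all samples in the node, so cost per
   node is roughly O(features x candidates x samples). */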

static float construct_class_ques(int feat,WQuestion &ques,WVectorVector &ds)
{
    // Find out which member of a class gives the best split
    float tscore,bscore = WGN_HUGE_VAL;
    int cl;
    WQuestion test_q;

    test_q.set_fp(feat);
    test_q.set_oper(wnop_is);
    ques = test_q;

    for (cl=0; cl < wgn_discretes[wgn_dataset.ftype(feat)].length(); cl++)
    {
        test_q.set_operand1(EST_Val(cl));
        tscore = wgn_score_question(test_q,ds);
        if (tscore < bscore)
        {
            ques = test_q;
            bscore = tscore;
        }
    }

    return bscore;
}

#if 0
static float construct_class_ques_subset(int feat,WQuestion &ques,
                                         WVectorVector &ds)
{
    // Find out which subset of a class gives the best split.
    // We first measure the subset of the data for each member of
    // the class.  Then order those splits.  Then go through finding
    // where the best split of that ordered list is.  This is described
    // on page 247 of Breiman et al.
    float tscore,bscore = WGN_HUGE_VAL;
    LISP l;
    int cl;

    ques.set_fp(feat);
    ques.set_oper(wnop_is);
    float *scores = new float[wgn_discretes[wgn_dataset.ftype(feat)].length()];

    // Only do it for existing values
    for (cl=0; cl < wgn_discretes[wgn_dataset.ftype(feat)].length(); cl++)
    {
        ques.set_operand(flocons(cl));
        scores[cl] = wgn_score_question(ques,ds);
    }

    LISP order = sort_class_scores(feat,scores);
    if (order == NIL)
        return WGN_HUGE_VAL;
    if (siod_llength(order) == 1)
    {   // Only one so we know the best "split"
        ques.set_oper(wnop_is);
        ques.set_operand(car(order));
        return scores[get_c_int(car(order))];
    }

    ques.set_oper(wnop_in);
    LISP best_l = NIL;
    for (l=cdr(order); CDR(l) != NIL; l = cdr(l))
    {
        ques.set_operand(l);
        tscore = wgn_score_question(ques,ds);
        if (tscore < bscore)
        {
            best_l = l;
            bscore = tscore;
        }
    }

    if (best_l != NIL)
    {
        if (siod_llength(best_l) == 1)
        {
            ques.set_oper(wnop_is);
            ques.set_operand(car(best_l));
        }
        else if (equal(cdr(order),best_l) != NIL)
        {
            ques.set_oper(wnop_is);
            ques.set_operand(car(order));
        }
        else
        {
            cout << "Found a good subset" << endl;
            ques.set_operand(best_l);
        }
    }
    return bscore;
}

static LISP sort_class_scores(int feat,float *scores)
{
    // returns sorted list of (non WGN_HUGE_VAL) items
    int i;
    LISP items = NIL;
    LISP l;

    for (i=0; i < wgn_discretes[wgn_dataset.ftype(feat)].length(); i++)
    {
        if (scores[i] != WGN_HUGE_VAL)
        {
            if (items == NIL)
                items = cons(flocons(i),NIL);
            else
            {
                for (l=items; l != NIL; l=cdr(l))
                {
                    if (scores[i] < scores[get_c_int(car(l))])
                    {
                        CDR(l) = cons(car(l),cdr(l));
                        CAR(l) = flocons(i);
                        break;
                    }
                }
                if (l == NIL)
                    items = l_append(items,cons(flocons(i),NIL));
            }
        }
    }
    return items;
}
#endif

static float construct_float_ques(int feat,WQuestion &ques,WVectorVector &ds)
{
    // Find a split of the range that gives the best score
    // Naively does this by partitioning the range into float_range_split slots
    float tscore,bscore = WGN_HUGE_VAL;
    int d, i;
    float p;
    WQuestion test_q;
    float max,min,val,incr;

    test_q.set_fp(feat);
    test_q.set_oper(wnop_lessthan);
    ques = test_q;

    min = max = ds(0)->get_flt_val(feat);  /* set up some value */
    for (d=0; d < ds.n(); d++)
    {
        val = ds(d)->get_flt_val(feat);
        if (val < min)
            min = val;
        else if (val > max)
            max = val;
    }
    if (max == min)  // we're pure
        return WGN_HUGE_VAL;
    incr = (max-min)/wgn_float_range_split;
    // so do float_range-1 splits
    /* We calculate this based on the number of splits, not the increments, */
    /* because incr can be so small it doesn't increment p                  */
    for (i=0,p=min+incr; i < wgn_float_range_split; i++,p += incr)
    {
        test_q.set_operand1(EST_Val(p));
        tscore = wgn_score_question(test_q,ds);
        if (tscore < bscore)
        {
            ques = test_q;
            bscore = tscore;
        }
    }

    return bscore;
}
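
/* Worked example (editorial note): with min=0.0, max=5.0 and
   wgn_float_range_split=10, incr is 0.5 and the candidate questions are
   "feat < 0.5", "feat < 1.0", ..., "feat < 5.0" -- ten thresholds
   evenly spaced across the observed range of the feature. */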

static void construct_binary_ques(int feat,WQuestion &test_ques)
{
    // Construct a question.  Not sure about this in general;
    // of course continuous/categorical features will require different
    // rules, and non-binary ones will require some test point

    test_ques.set_fp(feat);
    test_ques.set_oper(wnop_binary);
    test_ques.set_operand1(EST_Val(""));
}

static float score_question_set(WQuestion &q, WVectorVector &ds, int ignorenth)
{
    // score this question as a possible split by finding
    // the sum of the impurities when ds is split with this question
    WImpurity y,n;
    int d, num_yes, num_no;
    float count;
    WVector *wv;

    num_yes = num_no = 0;
    for (d=0; d < ds.n(); d++)
    {
        if ((ignorenth < 2) ||
            (d%ignorenth != ignorenth-1))
        {
            wv = ds(d);
            if (wgn_count_field == -1)
                count = 1.0;
            else
                count = (*wv)[wgn_count_field];

            if (q.ask(*wv) == TRUE)
            {
                num_yes++;
                y.cumulate((*wv)[wgn_predictee],count);
            }
            else
            {
                num_no++;
                n.cumulate((*wv)[wgn_predictee],count);
            }
        }
    }

    q.set_yes(num_yes);
    q.set_no(num_no);

    int min_cluster;

    if ((wgn_balance == 0.0) ||
        (ds.n()/wgn_balance < wgn_min_cluster_size))
        min_cluster = wgn_min_cluster_size;
    else
        min_cluster = (int)(ds.n()/wgn_balance);

    if ((y.samples() < min_cluster) ||
        (n.samples() < min_cluster))
        return WGN_HUGE_VAL;

    float ym,nm,bm;
    ym = y.measure();
    nm = n.measure();
    bm = ym + nm;

/*  cout << q << endl;
    printf("test question y %f n %f b %f\n",
           ym, nm, bm); */

    return bm/2.0;
}
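
/* Editorial note: when wgn_balance > 0 the minimum leaf size scales
   with the node itself (ds.n()/wgn_balance) instead of the global
   wgn_min_cluster_size, biasing the search toward more balanced
   splits; the score returned is the mean of the two child impurity
   measures, (ym+nm)/2. */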

float wgn_score_question(WQuestion &q, WVectorVector &ds)
{
    // This level of indirection was introduced for later expansion

    return score_question_set(q,ds,1);
}

WNode *wagon_stepwise(float limit)
{
    // Find the best single feature and incrementally add the features
    // that best improve the result, until it stops improving.
    // This is basically to automate what Kurt was doing in building
    // trees: he then automated it in Perl, and as it seemed to work
    // I put it into wagon itself.
    // This can be pretty computationally intensive.
    WNode *best = 0,*new_best = 0;
    float bscore,best_score = -WGN_HUGE_VAL;
    int best_feat,i;
    int nf = 1;

    // Set all features to ignore
    for (i=0; i < wgn_dataset.width(); i++)
        wgn_dataset.set_ignore(i,TRUE);

    for (i=0; i < wgn_dataset.width(); i++)
    {
        if ((wgn_dataset.ftype(i) == wndt_ignore) || (i == wgn_predictee))
        {
            // This skips the round, not because this has anything to
            // do with this feature being (user specified) ignored,
            // but because it indicates there is one less cycle that is
            // necessary
            continue;
        }
        new_best = wagon_stepwise_find_next_best(bscore,best_feat);

        if ((bscore - fabs(bscore * (limit/100))) <= best_score)
        {
            // gone as far as we can
            delete new_best;
            break;
        }
        else
        {
            best_score = bscore;
            delete best;
            best = new_best;
            wgn_dataset.set_ignore(best_feat,FALSE);
            if (!wgn_quiet)
            {
                fprintf(stdout,"FEATURE %d %s: %2.4f\n",
                        nf,
                        (const char *)wgn_dataset.feat_name(best_feat),
                        best_score);
                fflush(stdout);
                nf++;
            }
        }
    }

    return best;
}
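
/* Editorial note: the stopping test above discounts the new score by
   "limit" percent of itself, so each newly added feature must improve
   on the best score so far by more than limit% or the search ends and
   the previously selected feature set is returned. */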

static WNode *wagon_stepwise_find_next_best(float &bscore,int &best_feat)
{
    // Find which of the currently ignored features will best improve
    // the result
    WNode *best = 0;
    float best_score = -WGN_HUGE_VAL;
    int best_new_feat = -1;
    int i;

    for (i=0; i < wgn_dataset.width(); i++)
    {
        if (wgn_dataset.ftype(i) == wndt_ignore)
            continue;  // user wants me to ignore this completely
        else if (i == wgn_predictee)  // can't use the answer
            continue;
        else if (wgn_dataset.ignore(i) == TRUE)
        {
            WNode *current;
            float score;

            // Allow this feature to participate
            wgn_dataset.set_ignore(i,FALSE);

            current = wgn_build_tree(score);

            if (score > best_score)
            {
                best_score = score;
                delete best;
                best = current;
                best_new_feat = i;
//              fprintf(stdout,"BETTER FEATURE %d %s: %2.4f\n",
//                      i,
//                      (const char *)wgn_dataset.feat_name(i),
//                      best_score);
//              fflush(stdout);
            }
            else
                delete current;

            // switch it off again
            wgn_dataset.set_ignore(i,TRUE);
        }
    }

    bscore = best_score;
    best_feat = best_new_feat;
    return best;
}
Definition: EST_TList.h:141