49 #include "EST_Token.h"
50 #include "EST_FMatrix.h"
51 #include "EST_multistats.h"
52 #include "EST_Wagon.h"
// Global tuning parameters for wagon tree building (only some of the
// globals are visible in this excerpt).
// Floor on the number of samples required on each side of a split
// (see the min_cluster computation in score_question_set).
64 int wgn_min_cluster_size = 50;
// Output-control flags; their uses are not visible in this excerpt.
67 int wgn_quiet = FALSE;
68 int wgn_verbose = FALSE;
// Index of a per-sample weight field; -1 means no weight field is read
// (the weight used in that case is set in lines not shown, presumably 1).
69 int wgn_count_field = -1;
// Index of the field being predicted; never used as a split feature.
71 int wgn_predictee = 0;
// Number of candidate thresholds tried when splitting a float feature.
73 float wgn_float_range_split = 10;
// When non-zero, the minimum split size becomes ds.n()/wgn_balance if that
// exceeds wgn_min_cluster_size.
74 float wgn_balance = 0;
// Forward declarations of file-local helpers defined later in this file.
// Tree evaluation: each returns a score (higher is better) and may write a
// report to `output`.
79 static float do_summary(
WNode &tree,
WDataSet &ds,ostream *output);
80 static float test_tree_float(
WNode &tree,
WDataSet &ds,ostream *output);
81 static float test_tree_class(
WNode &tree,
WDataSet &ds,ostream *output);
82 static float test_tree_cluster(
WNode &tree,
WDataSet &dataset, ostream *output);
83 static float test_tree_vector(
WNode &tree,
WDataSet &dataset,ostream *output);
84 static float test_tree_trajectory(
WNode &tree,
WDataSet &dataset,ostream *output);
// Tree growing and question construction.
85 static int wagon_split(
int margin,
WNode &node);
87 static void construct_binary_ques(
int feat,
WQuestion &test_ques);
// Stepwise feature selection helper.
91 static WNode *wagon_stepwise_find_next_best(
float &bscore,
int &best_feat);
// Declare the EST list and vector template specialisations for WVector*.
93 Declare_TList_T(
WVector *, WVectorP)
95 Declare_TVector_Base_T(
WVector *,NULL,NULL,WVectorP)
// Explicit template instantiation, only when INSTANTIATE_TEMPLATES is defined.
97 #if defined(INSTANTIATE_TEMPLATES)
99 #include "../base_class/EST_TList.cc"
100 #include "../base_class/EST_TVector.cc"
102 Instantiate_TList_T(WVector *, WVectorP)
104 Instantiate_TVector(WVector *)
// Load the field description in `fname` into BOTH the training dataset and
// the test dataset, so the two share one field layout. `ignores` is passed
// straight through to load_description.
108 void wgn_load_datadescription(
EST_String fname,LISP ignores)
111 wgn_dataset.load_description(fname,ignores);
112 wgn_test_dataset.load_description(fname,ignores);
// NOTE(review): body of the data-file loader; its signature and several
// interior lines are outside this excerpt. It reads `fname` one line per
// sample into `dataset`.
// Failure to open the file is reported through wagon_error.
123 if (ts.
open(fname) == -1)
124 wagon_error(
EST_String(
"unable to open data file \"")+
// One WVector per sample, sized to the dataset's field count.
132 v =
new WVector(dataset.width());
// Parse field i according to its declared type.
136 int type = dataset.ftype(i);
// Float fields, and the weight (count) field, are parsed as floats.
// NOTE(review): atof cannot report malformed tokens; strtod with an end
// pointer would let bad input be detected directly.
137 if ((type == wndt_float) || (wgn_count_field == i))
140 float f = atof(ts.
get().string());
// A rejected float is reported and stored as 0.0 (the test that rejects
// it is not visible in this excerpt).
145 cout << fname <<
": bad float " << f
147 dataset.feat_name(i) <<
" vector " <<
148 dataset.samples() << endl;
149 v->set_flt_val(i,0.0);
// Binary, cluster, vector and trajectory fields are stored as integers.
152 else if (type == wndt_binary)
153 v->set_int_val(i,atoi(ts.
get().string()));
154 else if (type == wndt_cluster)
155 v->set_int_val(i,atoi(ts.
get().string()));
156 else if (type == wndt_vector)
157 v->set_int_val(i,atoi(ts.
get().string()));
158 else if (type == wndt_trajectory)
164 v->set_int_val(i,atoi(ts.
get().string()));
166 else if (type == wndt_ignore)
// Remaining (discrete) fields: look the token up in the discrete set for
// this field's type; an unrecognised value is reported.
174 int n = wgn_discretes.discrete(type).
name(s);
177 cout << fname <<
": bad value " << s <<
" in field " <<
178 dataset.feat_name(i) <<
" vector " <<
179 dataset.samples() << endl;
// Continue reading fields until end of line or until the vector is full.
186 while (!ts.
eoln() && i<dataset.width());
// Fewer fields than the description declares: reported via wagon_error.
188 if (i != dataset.width())
190 wagon_error(fname+
": data vector "+
itoString(nvec)+
" contains "
// More fields than declared: reported on cerr, then via wagon_error.
196 cerr << fname <<
": data vector " << nvec <<
197 " contains too many parameters instead of "
198 << dataset.width() << endl;
199 wagon_error(
EST_String(
"extra parameter(s) from ")+
// Summary of what was loaded.
205 cout <<
"Dataset of " << dataset.samples() <<
" vectors of " <<
206 dataset.width() <<
" parameters from: " << fname << endl;
// Score `tree` and write a report to `output`. Uses the test dataset when
// it has any samples, otherwise the training dataset. Returns the score
// from do_summary.
// NOTE(review): callers may pass output == 0 (wgn_build_tree does).
210 float summary_results(
WNode &tree,ostream *output)
212 if (wgn_test_dataset.samples() != 0)
213 return do_summary(tree,wgn_test_dataset,output);
215 return do_summary(tree,wgn_dataset,output);
// Dispatch to the evaluation routine that matches the predictee's type.
// The type comes from wgn_dataset's description while `ds` is the data
// actually scored; wgn_load_datadescription loads the same description
// into both the training and test datasets.
// Any remaining type >= wndt_class is treated as a discrete class; all
// other types are evaluated as floats.
218 static float do_summary(
WNode &tree,
WDataSet &ds,ostream *output)
220 if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
221 return test_tree_cluster(tree,ds,output);
222 else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
223 return test_tree_vector(tree,ds,output);
224 else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
225 return test_tree_trajectory(tree,ds,output);
226 else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
227 return test_tree_class(tree,ds,output);
229 return test_tree_float(tree,ds,output);
// Build a tree from wgn_dataset. `score` receives summary_results for the
// final tree. Creation of the root `top` is not visible in this excerpt.
232 WNode *wgn_build_tree(
float &score)
// Load the root with the samples selected by wgn_set_up_data(..., TRUE).
// With in == TRUE it keeps samples whose index % 100 >= wgn_held_out.
238 wgn_set_up_data(top->get_data(),wgn_dataset,wgn_held_out,TRUE);
// Grow the tree by recursive splitting.
241 wagon_split(margin,*top);
// When data was held out, reload the root with the in == FALSE selection
// (presumably the held-out samples) and call held_out_prune.
243 if (wgn_held_out > 0)
245 wgn_set_up_data(top->get_data(),wgn_dataset,wgn_held_out,FALSE);
246 top->held_out_prune();
// Score the tree with a null output stream, presumably to suppress the report.
252 score = summary_results(*top,0);
// NOTE(review): fragment of presumably wgn_set_up_data (signature not
// visible). Walks every sample of ds with index j. When `in` is true a
// sample is taken only if j % 100 >= held_out, so held_out acts as a
// percentage of each block of 100 samples kept aside.
267 for (j=i=0,d=ds.head(); d != 0; d=d->next(),j++)
269 if ((in) && ((j%100) >= held_out))
// Evaluate a tree whose predictee is a discrete class.
// For each sample: take the class predicted by its leaf, and add the
// (count-weighted) log of the probability the leaf gives that class to H.
// Prints a confusion matrix (print_confusion), entropy and perplexity.
// Returns -perplexity when wgn_opt_param is "entropy", otherwise
// correct/total (the updates of correct and total are not visible here).
284 static float test_tree_class(
WNode &tree,
WDataSet &dataset,ostream *output)
294 float correct=0,total=0, count=0;
296 for (p=dataset.head(); p != 0; p=p->next())
298 pnode = tree.predict_node((*dataset(p)));
299 predict = (
EST_String)pnode->get_impurity().value();
// Per-sample weight, read from the count field when one is configured.
300 if (wgn_count_field == -1)
303 count = dataset(p)->get_flt_val(wgn_count_field);
304 prob = pnode->get_impurity().pd().probability(predict);
// NOTE(review): log() is the natural logarithm, but perplexity below is
// pow(2.0, ...), so the bases disagree. Use log2() here or exp() below if a
// true perplexity is intended. Ranking by the "entropy" score is unaffected
// because pow(2, x) is monotonic.
305 H += (log(prob))*count;
306 type = dataset.ftype(wgn_predictee);
307 real = wgn_discretes[type].name(dataset(p)->get_int_val(wgn_predictee));
// Class-name lexicon for the confusion matrix.
313 for (i=0; i<wgn_discretes[dataset.ftype(wgn_predictee)].length(); i++)
314 lex.
append(wgn_discretes[dataset.ftype(wgn_predictee)].name(i));
320 print_confusion(m,pairs,lex);
// NOTE(review): wgn_build_tree passes output == 0. The null check that must
// guard this dereference is not visible in this excerpt; confirm it exists.
321 *output <<
";; entropy " << (-1*(H/total)) <<
" perplexity " <<
322 pow(2.0,(-1*(H/total))) << endl;
// NOTE(review): both returns divide by total; an empty dataset gives 0/0.
326 if (wgn_opt_param ==
"entropy")
327 return -pow(2.0,(-1*(H/total)));
329 return (
float)correct/(float)total;
// Evaluate a tree whose predictee is a row index into wgn_VertexTrack.
// For each sample and each track channel j with wgn_VertexFeats.a(0,j) > 0
// (the channel loop header is not visible), the leaf members' values in
// channel j are accumulated in b. They are compared with the sample's own
// row `pos`. Count-weighted error, squared error and cross-product
// accumulators feed the RMSE and correlation report.
// Returns -RMSE when wgn_opt_param is "rmse"; the other return path is not
// visible in this excerpt.
332 static float test_tree_vector(
WNode &tree,
WDataSet &dataset,ostream *output)
339 float predict, actual;
347 for (p=dataset.head(); p != 0; p=p->next())
349 leaf = tree.predict_node((*dataset(p)));
350 pos = dataset(p)->get_int_val(wgn_predictee);
352 if (wgn_VertexFeats.
a(0,j) > 0.0)
// Gather channel j over all training members that reached this leaf.
355 for (pp=leaf->get_impurity().members.head(); pp != 0; pp=pp->next())
357 i = leaf->get_impurity().members.
item(pp);
358 b += wgn_VertexTrack.
a(i,j);
361 actual = wgn_VertexTrack.
a(pos,j);
// Per-sample weight, read from the count field when one is configured.
362 if (wgn_count_field == -1)
365 count = dataset(p)->get_flt_val(wgn_count_field);
366 x.cumulate(predict,count);
367 y.cumulate(actual,count);
// Raw or stddev-normalised error; the branch choosing between them is not
// visible here.
370 error = predict-actual;
372 error = (predict-actual)/b.
stddev();
// NOTE(review): this unconditional assignment overwrites the error computed
// just above, so the stddev-normalised value is never used. Remove one of
// the two if that is unintended.
373 error = predict-actual;
374 se.cumulate((error*error),count);
375 e.cumulate(fabs(error),count);
376 xx.cumulate(predict*predict,count);
377 yy.cumulate(actual*actual,count);
378 xy.cumulate(predict*actual,count);
404 <<
" Correlation is " <<
ftoString(cor,4,1)
409 <<
" Correlation is " <<
ftoString(cor,4,1)
414 if (wgn_opt_param ==
"rmse")
415 return -sqrt(se.
mean());
// Evaluate a tree whose predictee is a trajectory, stored as a row index
// into wgn_VertexTrack. The visible logic matches test_tree_vector:
// - Channels with wgn_VertexFeats.a(0,j) > 0 are compared.
// - The leaf members' values in each channel are accumulated in b.
// - Count-weighted error statistics feed the RMSE and correlation report.
// Returns -RMSE when wgn_opt_param is "rmse"; the other return path is not
// visible in this excerpt.
420 static float test_tree_trajectory(
WNode &tree,
WDataSet &dataset,ostream *output)
428 float predict, actual;
436 for (p=dataset.head(); p != 0; p=p->next())
438 leaf = tree.predict_node((*dataset(p)));
439 pos = dataset(p)->get_int_val(wgn_predictee);
441 if (wgn_VertexFeats.
a(0,j) > 0.0)
// Gather channel j over all training members that reached this leaf.
444 for (pp=leaf->get_impurity().members.head(); pp != 0; pp=pp->next())
446 i = leaf->get_impurity().members.
item(pp);
447 b += wgn_VertexTrack.
a(i,j);
450 actual = wgn_VertexTrack.
a(pos,j);
// Per-sample weight, read from the count field when one is configured.
451 if (wgn_count_field == -1)
454 count = dataset(p)->get_flt_val(wgn_count_field);
455 x.cumulate(predict,count);
456 y.cumulate(actual,count);
// Raw or stddev-normalised error; the branch choosing between them is not
// visible here.
459 error = predict-actual;
461 error = (predict-actual)/b.
stddev();
// NOTE(review): this unconditional assignment overwrites the error computed
// just above, so the stddev-normalised value is never used. Remove one of
// the two if that is unintended.
462 error = predict-actual;
463 se.cumulate((error*error),count);
464 e.cumulate(fabs(error),count);
465 xx.cumulate(predict*predict,count);
466 yy.cumulate(actual*actual,count);
467 xy.cumulate(predict*actual,count);
493 <<
" Correlation is " <<
ftoString(cor,4,1)
498 <<
" Correlation is " <<
ftoString(cor,4,1)
503 if (wgn_opt_param ==
"rmse")
504 return -sqrt(se.
mean());
// Evaluate a tree whose predictee is a cluster id. For each sample, the
// predicted leaf's impurity supplies three statistics, which are
// accumulated: cluster_distance, in_cluster and cluster_ranking.
// The summary is written to `output` and to cout.
// Returns 10000 minus the mean distance, so a smaller distance yields a
// larger (better) score.
509 static float test_tree_cluster(
WNode &tree,
WDataSet &dataset,ostream *output)
518 for (p=dataset.head(); p != 0; p=p->next())
520 leaf = tree.predict_node((*dataset(p)));
521 real = dataset(p)->get_int_val(wgn_predictee);
522 meandist += leaf->get_impurity().cluster_distance(real);
523 right_cluster += leaf->get_impurity().in_cluster(real);
524 ranking += leaf->get_impurity().cluster_ranking(real);
// NOTE(review): the percentages divide by dataset.length(); an empty
// dataset divides by zero.
531 *output <<
";; Right cluster " << right_cluster <<
" (" <<
532 (int)(100.0*(
float)right_cluster/(
float)dataset.length()) <<
533 "%) mean ranking " << ranking.mean() << " mean distance "
534 << meandist.mean() << endl;
535 cout << "Right cluster " << right_cluster << " (" <<
536 (
int)(100.0*(
float)right_cluster/(
float)dataset.length()) <<
537 "%) mean ranking " << ranking.mean() << " mean distance "
538 << meandist.mean() << endl;
541 return 10000-meandist.mean();
// Evaluate a tree whose predictee is a float. For each sample, compare the
// tree's prediction with the real value. Count-weighted error, squared error
// and cross-product accumulators feed the RMSE and correlation report.
// Returns -RMSE when wgn_opt_param is "rmse"; the other return path is not
// visible in this excerpt.
544 static
float test_tree_float(
WNode &tree,
WDataSet &dataset,ostream *output)
553 for (p=dataset.head(); p != 0; p=p->next())
555 predict = tree.predict((*dataset(p)));
556 real = dataset(p)->get_flt_val(wgn_predictee);
// Per-sample weight, read from the count field when one is configured.
557 if (wgn_count_field == -1)
560 count = dataset(p)->get_flt_val(wgn_count_field);
561 x.cumulate(predict,count);
562 y.cumulate(real,count);
563 error = predict-real;
564 se.cumulate((error*error),count);
565 e.cumulate(fabs(error),count);
566 xx.cumulate(predict*predict,count);
567 yy.cumulate(real*real,count);
568 xy.cumulate(predict*real,count);
593 <<
" Correlation is " <<
ftoString(cor,4,1)
598 <<
" Correlation is " <<
ftoString(cor,4,1)
603 if (wgn_opt_param ==
"rmse")
604 return -sqrt(se.
mean());
// Recursively grow the subtree at `node`.
// 1. Compute the node's impurity and the best question over its data.
// 2. Split only if the question is valid (score below WGN_HUGE_VAL) and
//    scores lower than the node's own impurity measure.
// 3. On a split, partition the data into children l and r, record the
//    question, and recurse into both children.
// Otherwise the node stays a leaf and a "stopped" line is printed (any
// verbosity guard is not visible). `margin` bounds the printing loops,
// presumably as indentation.
609 static int wagon_split(
int margin,
WNode &node)
615 node.set_impurity(
WImpurity(node.get_data()));
616 q = find_best_question(node.get_data());
622 double impurity_measure = node.get_impurity().measure();
623 double question_score = q.get_score();
// Lower scores are better; WGN_HUGE_VAL means no usable question was found.
625 if ((question_score < WGN_HUGE_VAL) &&
626 (question_score < impurity_measure))
632 wgn_find_split(q,node.get_data(),l->get_data(),r->get_data());
633 node.set_subnodes(l,r);
634 node.set_question(q);
638 for (i=0; i < margin; i++)
643 wagon_split(margin,*l);
645 wagon_split(margin,*r);
654 for (i=0; i < margin; i++)
656 cout <<
"stopped samples: " << node.samples() <<
" impurity: "
657 << node.get_impurity() << endl;
// NOTE(review): fragment of the data-partitioning routine (presumably
// wgn_find_split; signature not visible). Asks q of every sample in ds.
// iy and in presumably count the yes and no answers.
672 for (iy=in=i=0; i < ds.
n(); i++)
673 if (q.ask(*ds(i)) == TRUE)
// NOTE(review): body of find_best_question (called from wagon_split); its
// signature is not visible here. Scans every field, builds the best
// question of the kind matching the field's type, and tracks the lowest
// score. WGN_HUGE_VAL means "no usable question".
687 bscore = tscore = WGN_HUGE_VAL;
688 best_ques.set_score(bscore);
690 for (i=0;i < wgn_dataset.width(); i++)
// Ignored fields and the predictee itself are never split on.
692 if ((wgn_dataset.ignore(i) == TRUE) ||
693 (i == wgn_predictee))
694 tscore = WGN_HUGE_VAL;
695 else if (wgn_dataset.ftype(i) == wndt_binary)
697 construct_binary_ques(i,test_ques);
698 tscore = wgn_score_question(test_ques,dset);
700 else if (wgn_dataset.ftype(i) == wndt_float)
702 tscore = construct_float_ques(i,test_ques,dset);
704 else if (wgn_dataset.ftype(i) == wndt_ignore)
705 tscore = WGN_HUGE_VAL;
// NOTE(review): wagon_error is called before construct_class_ques_subset.
// Per its message, subset selection is disabled; the subset call is reached
// only if wagon_error returns. Confirm.
708 else if (wgn_csubset && (wgn_dataset.ftype(i) >= wndt_class))
710 wagon_error(
"subset selection temporarily deleted");
711 tscore = construct_class_ques_subset(i,test_ques,dset);
714 else if (wgn_dataset.ftype(i) >= wndt_class)
715 tscore = construct_class_ques(i,test_ques,dset);
// Record this candidate as the best so far. The comparison against bscore
// that guards this update is not visible in this excerpt.
718 best_ques = test_ques;
719 best_ques.set_score(tscore);
// NOTE(review): fragment of presumably construct_class_ques (signature not
// visible). Tries an "is <value>" question for every value in the discrete
// set of `feat`'s type and scores each on ds. bscore starts at WGN_HUGE_VAL,
// so lower scores are better.
730 float tscore,bscore = WGN_HUGE_VAL;
735 test_q.set_oper(wnop_is);
738 for (cl=0; cl < wgn_discretes[wgn_dataset.ftype(feat)].length(); cl++)
740 test_q.set_operand1(
EST_Val(cl));
741 tscore = wgn_score_question(test_q,ds);
// Build a class question on `feat` that may test membership of a SUBSET of
// its values. Part of the signature is outside this excerpt.
// Each single value is first scored with an "is" question. The values are
// then ordered by score (sort_class_scores), and growing subsets are scored
// with an "in" question. Returns a score; lower is better.
// NOTE(review): find_best_question calls wagon_error ("subset selection
// temporarily deleted") just before calling this, so it appears disabled.
753 static float construct_class_ques_subset(
int feat,
WQuestion &ques,
761 float tscore,bscore = WGN_HUGE_VAL;
766 ques.set_oper(wnop_is);
// Per-value scores, indexed by value number.
767 float *scores =
new float[wgn_discretes[wgn_dataset.ftype(feat)].length()];
770 for (cl=0; cl < wgn_discretes[wgn_dataset.ftype(feat)].length(); cl++)
772 ques.set_operand(flocons(cl));
773 scores[cl] = wgn_score_question(ques,ds);
// Value indices with a finite score, in increasing score order.
776 LISP order = sort_class_scores(feat,scores);
// Only one candidate value: fall back to a plain "is" question.
779 if (siod_llength(order) == 1)
781 ques.set_oper(wnop_is);
782 ques.set_operand(car(order));
// NOTE(review): no delete[] of `scores` is visible before this return;
// check whether the array leaks on this path.
783 return scores[get_c_int(car(order))];
// NOTE(review): if no value had a finite score, `order` is empty, and CDR(l)
// on cdr(order) may be invalid. Confirm that an empty list is handled in
// lines not shown.
786 ques.set_oper(wnop_in);
788 for (l=cdr(order); CDR(l) != NIL; l = cdr(l))
791 tscore = wgn_score_question(ques,ds);
// Collapse a one-element best subset back to an "is" question.
802 if (siod_llength(best_l) == 1)
804 ques.set_oper(wnop_is);
805 ques.set_operand(car(best_l));
806 // If the best subset equals cdr(order), ask "is car(order)" instead.
807 else if (equal(cdr(order),best_l) != NIL)
809 ques.set_oper(wnop_is);
810 ques.set_operand(car(order));
814 cout <<
"Found a good subset" << endl;
815 ques.set_operand(best_l);
// Return a LISP list of the value indices of `feat` whose score is not
// WGN_HUGE_VAL, ordered by increasing score (an insertion sort).
// Each index is inserted before the first list element with a larger score.
// It does this by duplicating that cell into the CDR; the CAR is presumably
// overwritten in lines not shown. Otherwise the index is appended at the end.
821 static LISP sort_class_scores(
int feat,
float *scores)
828 for (i=0; i < wgn_discretes[wgn_dataset.ftype(feat)].length(); i++)
830 if (scores[i] != WGN_HUGE_VAL)
// Start the list (presumably when it is still empty).
833 items = cons(flocons(i),NIL);
836 for (l=items; l != NIL; l=cdr(l))
838 if (scores[i] < scores[get_c_int(car(l))])
840 CDR(l) = cons(car(l),cdr(l));
// Larger than every element so far: append at the end.
846 items = l_append(items,cons(flocons(i),NIL));
// NOTE(review): fragment of construct_float_ques (signature not visible).
// Finds the min and max of float field `feat` over ds. It then scores
// "less than p" questions for wgn_float_range_split evenly spaced
// thresholds p = min+incr, min+2*incr, ...
// Lower scores are better (bscore starts at WGN_HUGE_VAL).
858 float tscore,bscore = WGN_HUGE_VAL;
862 float max,min,val,incr;
865 test_q.set_oper(wnop_lessthan);
// NOTE(review): ds(0) is read unconditionally; an empty ds would be indexed
// out of range unless a guard exists in lines not shown.
868 min = max = ds(0)->get_flt_val(feat);
869 for (d=0; d < ds.
n(); d++)
871 val = ds(d)->get_flt_val(feat);
// NOTE(review): when max == min, incr is 0 and every candidate threshold is
// identical; confirm whether lines not shown handle a constant feature.
879 incr = (max-min)/wgn_float_range_split;
883 for (i=0,p=min+incr; i < wgn_float_range_split; i++,p += incr )
885 test_q.set_operand1(
EST_Val(p));
886 tscore = wgn_score_question(test_q,ds);
// Set test_ques to a wnop_binary question on field `feat`. The operand is
// an empty-string value (presumably unused by the binary operator).
897 static void construct_binary_ques(
int feat,
WQuestion &test_ques)
903 test_ques.set_fp(feat);
904 test_ques.set_oper(wnop_binary);
905 test_ques.set_operand1(
EST_Val(
""));
// NOTE(review): fragment of score_question_set(q, ds, ignorenth); the
// signature is not visible here.
// Partitions the samples of ds by their answer to q, accumulating the
// count-weighted predictee values into y (yes) and n (no).
913 int d, num_yes, num_no;
917 num_yes = num_no = 0;
918 for (d=0; d < ds.
n(); d++)
// When ignorenth >= 2, skip every ignorenth-th sample
// (d = ignorenth-1, 2*ignorenth-1, ...).
920 if ((ignorenth < 2) ||
921 (d%ignorenth != ignorenth-1))
924 if (wgn_count_field == -1)
927 count = (*wv)[wgn_count_field];
929 if (q.ask(*wv) == TRUE)
932 y.cumulate((*wv)[wgn_predictee],count);
937 n.cumulate((*wv)[wgn_predictee],count);
// Minimum side size is wgn_min_cluster_size, or ds.n()/wgn_balance when
// balancing is enabled and that value is larger.
947 if ((wgn_balance == 0.0) ||
948 (ds.
n()/wgn_balance < wgn_min_cluster_size))
949 min_cluster = wgn_min_cluster_size;
951 min_cluster = (int)(ds.
n()/wgn_balance);
// A side with too few samples presumably makes the question unusable (the
// consequence is not visible in this excerpt).
953 if ((y.samples() < min_cluster) ||
954 (n.samples() < min_cluster))
// Score q on all of ds: ignorenth = 1 disables score_question_set's
// every-nth-sample skipping.
973 return score_question_set(q,ds,1);
// Greedy forward feature selection. Starts with every field ignored, then
// repeatedly asks wagon_stepwise_find_next_best for the candidate feature
// that gives the highest tree score.
// A candidate is accepted only if its score, reduced by `limit` percent of
// its magnitude, still exceeds the best score so far. The action taken on
// rejection (presumably stopping) is not visible.
// Accepted features are un-ignored and reported on stdout. Higher scores
// are better.
976 WNode *wagon_stepwise(
float limit)
984 WNode *best = 0,*new_best = 0;
985 float bscore,best_score = -WGN_HUGE_VAL;
// Begin with no features enabled.
990 for (i=0; i < wgn_dataset.width(); i++)
991 wgn_dataset.set_ignore(i,TRUE);
993 for (i=0; i < wgn_dataset.width(); i++)
995 if ((wgn_dataset.ftype(i) == wndt_ignore) || (i == wgn_predictee))
1003 new_best = wagon_stepwise_find_next_best(bscore,best_feat);
// Require an improvement of more than `limit` percent.
1005 if ((bscore - fabs(bscore * (limit/100))) <= best_score)
1013 best_score = bscore;
1016 wgn_dataset.set_ignore(best_feat,FALSE);
1019 fprintf(stdout,
"FEATURE %d %s: %2.4f\n",
1021 (
const char *)wgn_dataset.feat_name(best_feat),
// Try adding each currently-ignored candidate field in turn. For each one:
// un-ignore it, build a tree, keep the best score seen, then ignore it
// again. Fields of type wndt_ignore and the predictee are skipped.
// On exit, bscore receives the best score and best_feat the best field
// (-1 if none was tried). The return value, presumably the best tree, is
// not visible.
1032 static WNode *wagon_stepwise_find_next_best(
float &bscore,
int &best_feat)
1037 float best_score = -WGN_HUGE_VAL;
1038 int best_new_feat = -1;
1041 for (i=0; i < wgn_dataset.width(); i++)
1043 if (wgn_dataset.ftype(i) == wndt_ignore)
1045 else if (i == wgn_predictee)
1047 else if (wgn_dataset.ignore(i) == TRUE)
// Temporarily enable field i and score a tree built with it.
1053 wgn_dataset.set_ignore(i,FALSE);
1055 current = wgn_build_tree(score);
1057 if (score > best_score)
// Restore field i to ignored before trying the next candidate.
1073 wgn_dataset.set_ignore(i,TRUE);
1077 bscore = best_score;
1078 best_feat = best_new_feat;
EST_TokenStream & get(EST_Token &t)
get next token in stream
double stddev(void) const
standard deviation of the currently accumulated values
float & a(int i, int c=0)
int num_channels() const
return number of channels in track
void set_SingleCharSymbols(const EST_String &sc)
set which characters are to be treated as single character symbols
double mean(void) const
mean of the currently accumulated values
void close(void)
Close stream.
const EST_String & name(const int n) const
The name given the index.
EST_String itoString(int n)
Make a EST_String object from an integer.
EST_String ftoString(float n, int pres=3, int width=0, int l=0)
Make a EST_String object from an float, with variable precision.
void set_PrePunctuationSymbols(const EST_String &ps)
set which characters are to be treated as (pre) punctuation
int open(const EST_String &filename)
open an EST_TokenStream for a file.
void set_PunctuationSymbols(const EST_String &ps)
set which characters are to be treated as (post) punctuation
void resize(int n, int set=1)
EST_Token & peek(void)
peek at next token
INLINE int n() const
number of items in vector.
void reset(void)
reset internal values
int add_item(const K &rkey, const V &rval, int no_search=0)
add key-val pair to list
void append(const T &item)
add item onto end of list
T & item(const EST_Litem *p)