Edinburgh Speech Tools  2.1-release
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
wagon_test_main.cc
1 /*************************************************************************/
2 /* */
3 /* Centre for Speech Technology Research */
4 /* University of Edinburgh, UK */
5 /* Copyright (c) 1996,1997 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : October 1997 */
35 /*-----------------------------------------------------------------------*/
36 /* A program for testing a CART tree against data, also may be used to */
37 /* predict values using a tree and data */
38 /* */
39 /*=======================================================================*/
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <iostream>
#include <fstream>
#include <cstring>
#include "EST_Wagon.h"
#include "EST_cutils.h"
#include "EST_multistats.h"
#include "EST_Token.h"
#include "EST_cmd_line.h"
49 
50 static int wagon_test_main(int argc, char **argv);
51 static LISP find_feature_value(const char *feature,
52  LISP vector, LISP description);
53 static LISP wagon_vector_predict(LISP tree, LISP vector, LISP description);
54 static LISP get_data_vector(EST_TokenStream &data, LISP description);
55 static void simple_predict(EST_TokenStream &data, FILE *output,
56  LISP tree, LISP description, int all_info);
57 static void test_tree_class(EST_TokenStream &data, FILE *output,
58  LISP tree, LISP description);
59 static void test_tree_float(EST_TokenStream &data, FILE *output,
60  LISP tree, LISP description);
61 
62 
63 int main(int argc, char **argv)
64 {
65 
66  wagon_test_main(argc,argv);
67 
68  exit(0);
69  return 0;
70 }
71 
72 static int wagon_test_main(int argc, char **argv)
73 {
74  // Top level function sets up data and creates a tree
75  EST_Option al;
76  EST_StrList files;
77  LISP description,tree=NIL;;
78  EST_TokenStream data;
79  FILE *wgn_output;
80 
81  parse_command_line
82  (argc, argv,
83  EST_String("<options>\n")+
84  "Summary: program to test CART models on data\n"+
85  "-desc <ifile> Field description file\n"+
86  "-data <ifile> Datafile, one vector per line\n"+
87  "-tree <ifile> File containing CART tree\n"+
88  "-track <ifile>\n"+
89  " track for vertex indices\n"+
90  "-predict Predict for each vector returning full vector\n"+
91  "-predict_val Predict for each vector returning just value\n"+
92  "-predictee <string>\n"+
93  " name of field to predict (default is first field)\n"+
94  "-heap <int> {210000}\n"+
95  " Set size of Lisp heap, should not normally need\n"+
96  " to be changed from its default\n"+
97  "-o <ofile> File to save output in\n",
98  files, al);
99 
100  siod_init(al.ival("-heap"));
101 
102  if (al.present("-desc"))
103  {
104  gc_protect(&description);
105  description = car(vload(al.val("-desc"),1));
106  }
107  else
108  {
109  cerr << argv[0] << ": no description file specified" << endl;
110  exit(-1);
111  }
112 
113  if (al.present("-tree"))
114  {
115  gc_protect(&tree);
116  tree = car(vload(al.val("-tree"),1));
117  if (tree == NIL)
118  {
119  cerr << argv[0] << ": no tree found in \"" << al.val("-tree")
120  << "\"" << endl;
121  exit(-1);
122  }
123  }
124  else
125  {
126  cerr << argv[0] << ": no tree file specified" << endl;
127  exit(-1);
128  }
129 
130  if (al.present("-data"))
131  {
132  if (data.open(al.val("-data")) != 0)
133  {
134  cerr << argv[0] << ": can't open data file \"" <<
135  al.val("-data") << "\" for input." << endl;
136  exit(-1);
137  }
138  }
139  else
140  {
141  cerr << argv[0] << ": no data file specified" << endl;
142  exit(-1);
143  }
144 
145  if (al.present("-track"))
146  {
147  wgn_VertexTrack.load(al.val("-track"));
148  }
149 
150  if (al.present("-o"))
151  {
152  if ((wgn_output = fopen(al.val("-o"),"w")) == NULL)
153  {
154  cerr << argv[0] << ": can't open output file \"" <<
155  al.val("-o") << "\"" << endl;
156  }
157  }
158  else
159  wgn_output = stdout;
160 
161  if (al.present("-predictee"))
162  {
163  LISP l;
164  int i;
165  wgn_predictee_name = al.val("-predictee");
166  for (l=description,i=0; l != NIL; l=cdr(l),i++)
167  if (streq(wgn_predictee_name,get_c_string(car(car(l)))))
168  {
169  wgn_predictee = i;
170  break;
171  }
172  if (l==NIL)
173  {
174  cerr << argv[0] << ": predictee \"" << wgn_predictee <<
175  "\" not in description\n";
176  }
177  }
178  const char *predict_type =
179  get_c_string(car(cdr(siod_nth(wgn_predictee,description))));
180 
181  if (al.present("-predict"))
182  simple_predict(data,wgn_output,tree,description,FALSE);
183  else if (al.present("-predict_val"))
184  simple_predict(data,wgn_output,tree,description,TRUE);
185  else if (streq(predict_type,"float") ||
186  streq(predict_type,"int"))
187  test_tree_float(data,wgn_output,tree,description);
188 #if 0
189  else if (streq(predict_type,"vector"))
190  test_tree_vector(data,wgn_output,tree,description);
191 #endif
192  else
193  test_tree_class(data,wgn_output,tree,description);
194 
195  if (wgn_output != stdout)
196  fclose(wgn_output);
197  data.close();
198  return 0;
199 }
200 
201 static LISP get_data_vector(EST_TokenStream &data, LISP description)
202 {
203  // read in one vector. Should be terminated with an newline
204  LISP v=NIL,d;
205 
206  if (data.eof())
207  return NIL;
208 
209  for (d=description; d != NIL; d=cdr(d))
210  {
211  EST_Token t = data.get();
212 
213  if ((d != description) && (t.whitespace().contains("\n")))
214  {
215  cerr << "wagon_test: unexpected newline within vector " <<
216  t.row() << " wrong number of features" << endl;
217  siod_error();
218  }
219  if (streq(get_c_string(car(cdr(car(d)))),"float") ||
220  streq(get_c_string(car(cdr(car(d)))),"int"))
221  v = cons(flocons(atof(t.string())),v);
222  else if ((streq(get_c_string(car(cdr(car(d)))),"_other_")) &&
223  (siod_member_str(t.string(),cdr(car(d))) == NIL))
224  v = cons(strintern("_other_"),v);
225  else
226  v = cons(strintern(t.string()),v);
227  }
228 
229  return reverse(v);
230 }
231 
232 static void simple_predict(EST_TokenStream &data, FILE *output,
233  LISP tree, LISP description, int all_info)
234 {
235  LISP vector,predict;
236  EST_String val;
237 
238  for (vector=get_data_vector(data,description);
239  vector != NIL; vector=get_data_vector(data,description))
240  {
241  predict = wagon_vector_predict(tree,vector,description);
242  if (all_info)
243  val = siod_sprint(car(reverse(predict)));
244  else
245  val = siod_sprint(predict);
246  fprintf(output,"%s\n",(const char *)val);
247  }
248 }
249 
250 static void test_tree_float(EST_TokenStream &data, FILE *output,
251  LISP tree, LISP description)
252 {
253  // Test tree against data to get summary of results FLOAT
254  float predict_val,real_val;
255  EST_SuffStats x,y,xx,yy,xy,se,e;
256  double cor,error;
257  LISP vector,predict;
258 
259  for (vector=get_data_vector(data,description);
260  vector != NIL; vector=get_data_vector(data,description))
261  {
262  predict = wagon_vector_predict(tree,vector,description);
263  predict_val = get_c_float(car(reverse(predict)));
264  real_val = get_c_float(siod_nth(wgn_predictee,vector));
265  x += predict_val;
266  y += real_val;
267  error = predict_val-real_val;
268  se += error*error;
269  e += fabs(error);
270  xx += predict_val*predict_val;
271  yy += real_val*real_val;
272  xy += predict_val*real_val;
273  }
274 
275  cor = (xy.mean() - (x.mean()*y.mean()))/
276  (sqrt(xx.mean()-(x.mean()*x.mean())) *
277  sqrt(yy.mean()-(y.mean()*y.mean())));
278 
279  fprintf(output,";; RMSE %1.4f Correlation is %1.4f Mean (abs) Error %1.4f (%1.4f)\n",
280  sqrt(se.mean()),
281  cor,
282  e.mean(),
283  e.stddev());
284 }
285 
286 static void test_tree_class(EST_TokenStream &data, FILE *output,
287  LISP tree, LISP description)
288 {
289  // Test tree against class data to get summary of results
290  EST_StrStr_KVL pairs;
291  EST_StrList lex;
292  EST_String predict_class,real_class;
293  LISP vector,w,predict;
294  double H=0,Q=0,prob;
295  (void)output;
296 
297  for (vector=get_data_vector(data,description);
298  vector != NIL; vector=get_data_vector(data,description))
299  {
300  predict = wagon_vector_predict(tree,vector,description);
301  predict_class = get_c_string(car(reverse(predict)));
302  real_class = get_c_string(siod_nth(wgn_predictee,vector));
303  prob = get_c_float(car(cdr(siod_assoc_str(real_class,
304  predict))));
305  if (prob == 0)
306  H += log(0.000001);
307  else
308  H += log(prob);
309  Q ++;
310  pairs.add_item(real_class,predict_class,1);
311  }
312  for (w=cdr(siod_nth(wgn_predictee,description)); w != NIL; w = cdr(w))
313  lex.append(get_c_string(car(w)));
314 
315  const EST_FMatrix &m = confusion(pairs,lex);
316  print_confusion(m,pairs,lex);
317  fprintf(stdout,";; entropy %g perplexity %g\n",
318  (-1*(H/Q)),pow(2.0,(-1*(H/Q))));
319 }
320 
321 static void test_tree_vector(EST_TokenStream &data, FILE *output,
322  LISP tree, LISP description)
323 {
324  // Test tree against class data to get summary of results
325  // Note we are talking about predicting vectors (a *bunch* of
326  // numbers, not just a single class here)
327  EST_StrStr_KVL pairs;
328  EST_StrList lex;
329  EST_String predict_class,real_class;
330  LISP vector,w,predict;
331  double H=0,Q=0,prob;
332  (void)output;
333 
334  for (vector=get_data_vector(data,description);
335  vector != NIL; vector=get_data_vector(data,description))
336  {
337  predict = wagon_vector_predict(tree,vector,description);
338  predict_class = get_c_string(car(reverse(predict)));
339  real_class = get_c_string(siod_nth(wgn_predictee,vector));
340  prob = get_c_float(car(cdr(siod_assoc_str(real_class,
341  predict))));
342  if (prob == 0)
343  H += log(0.000001);
344  else
345  H += log(prob);
346  Q ++;
347  pairs.add_item(real_class,predict_class,1);
348  }
349  for (w=cdr(siod_nth(wgn_predictee,description)); w != NIL; w = cdr(w))
350  lex.append(get_c_string(car(w)));
351 
352  const EST_FMatrix &m = confusion(pairs,lex);
353  print_confusion(m,pairs,lex);
354  fprintf(stdout,";; entropy %g perplexity %g\n",
355  (-1*(H/Q)),pow(2.0,(-1*(H/Q))));
356 }
357 
358 static LISP wagon_vector_predict(LISP tree, LISP vector, LISP description)
359 {
360  // Using the LISP tree, vector and description, do standard prediction
361 
362  if (cdr(tree) == NIL)
363  return car(tree);
364 
365  LISP value = find_feature_value(wgn_ques_feature(car(tree)),
366  vector, description);
367 
368  if (wagon_ask_question(car(tree),value))
369  // Yes answer
370  return wagon_vector_predict(car(cdr(tree)),vector,description);
371  else
372  // No answer
373  return wagon_vector_predict(car(cdr(cdr(tree))),vector,description);
374 }
375 
376 static LISP find_feature_value(const char *feature,
377  LISP vector, LISP description)
378 {
379  LISP v,d;
380 
381  for (v=vector,d=description; v != NIL; v=cdr(v),d=cdr(d))
382  if (streq(feature,get_c_string(car(car(d)))))
383  return car(v);
384 
385  cerr << "wagon_test: can't find feature \"" << feature <<
386  "\" in description" << endl;
387  siod_error();
388  return NIL;
389 
390 }
391 
int row(void) const
Line number in original EST_TokenStream.
Definition: EST_Token.h:186
EST_TokenStream & get(EST_Token &t)
get next token in stream
Definition: EST_Token.cc:486
double stddev(void) const
standard deviation of the currently accumulated values
int ival(const EST_String &rkey, int m=1) const
Definition: EST_Option.cc:76
double mean(void) const
mean of the currently accumulated values
void close(void)
Close stream.
Definition: EST_Token.cc:406
int open(const EST_String &filename)
open a EST_TokenStream for a file.
Definition: EST_Token.cc:200
int eof()
end of file
Definition: EST_Token.h:363
EST_read_status load(const EST_String name, float ishift=0.0, float startt=0.0)
Definition: EST_Track.cc:1309
const int present(const K &rkey) const
Returns true if key is present.
Definition: EST_TKVL.cc:222
const V & val(const K &rkey, bool m=0) const
return value according to key (const)
Definition: EST_TKVL.cc:145
int add_item(const K &rkey, const V &rval, int no_search=0)
add key-val pair to list
Definition: EST_TKVL.cc:248
void append(const T &item)
add item onto end of list
Definition: EST_TList.h:198
int contains(const char *s, int pos=-1) const
Does it contain this substring?
Definition: EST_String.h:378