Edinburgh Speech Tools  2.1-release
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
wagon_main.cc
1 /*************************************************************************/
2 /* */
3 /* Centre for Speech Technology Research */
4 /* University of Edinburgh, UK */
5 /* Copyright (c) 1996-2006 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : May 1996 */
35 /*-----------------------------------------------------------------------*/
36 /* A Classification and Regression Tree (CART) Program */
37 /* A basic implementation of many of the techniques in */
38 /* Breiman et al. 1984 */
39 /* */
40 /* Added decision list support, Feb 1997 */
41 /* */
42 /* Added vector support for Clustergen 2005/2006 */
43 /* */
44 /*=======================================================================*/
45 #include <cstdlib>
46 #include <iostream>
47 #include <fstream>
48 #include <cstring>
49 #include "EST_Wagon.h"
50 #include "EST_cmd_line.h"
51 
/* The two kinds of model this program can build. */
52 enum wn_strategy_type {wn_decision_list, wn_decision_tree};
53 
/* Kind of model to build; stays a decision tree unless wagon_main() */
/* sees the -dlist option. */
54 static wn_strategy_type wagon_type = wn_decision_tree;
55 
/* Parses the command line, loads the data, builds and prints the model. */
56 static int wagon_main(int argc, char **argv);
57 
58 
59 int main(int argc, char **argv)
60 {
61 
62  wagon_main(argc,argv);
63 
64  exit(0);
65  return 0;
66 }
67 
68 static int set_Vertex_Feats(EST_Track &wgn_VertexFeats,
69  EST_String &wagon_track_features)
70 {
71  int i,s=0,e;
72  EST_TokenStream ts;
73 
74  for (i=0; i<wgn_VertexFeats.num_channels(); i++)
75  wgn_VertexFeats.a(0,i) = 0.0;
76 
77  ts.open_string(wagon_track_features);
78  ts.set_WhiteSpaceChars(",- ");
81  ts.set_SingleCharSymbols("");
82 
83  while (!ts.eof())
84  {
85  EST_Token &token = ts.get();
86  const EST_String ws = (const char *)token.whitespace();
87  if (token == "all")
88  {
89  for (i=0; i<wgn_VertexFeats.num_channels(); i++)
90  wgn_VertexFeats.a(0,i) = 1.0;
91  break;
92  } else if ((ws == ",") || (ws == ""))
93  {
94  s = atoi(token.string());
95  wgn_VertexFeats.a(0,s) = 1.0;
96  } else if (ws == "-")
97  {
98  if (token == "")
99  e = wgn_VertexFeats.num_channels()-1;
100  else
101  e = atoi(token.string());
102  for (i=s; i<=e && i<wgn_VertexFeats.num_channels(); i++)
103  wgn_VertexFeats.a(0,i) = 1.0;
104  } else
105  {
106  printf("wagon: track_feats invalid: %s at position %d\n",
107  (const char *)wagon_track_features,
108  ts.filepos());
109  exit(-1);
110  }
111  }
112 
113  return 0;
114 }
115 
116 static int wagon_main(int argc, char **argv)
117 {
118  // Top level function sets up data and creates a tree
119  EST_Option al;
120  EST_StrList files;
121  EST_String wgn_oname;
122  ostream *wgn_coutput = 0;
123  float stepwise_limit = 0;
124  int feats_start=0, feats_end=0;
125  int i;
126 
127  parse_command_line
128  (argc, argv,
129  EST_String("[options]\n") +
130  "Summary: CART building program\n"+
131  "-desc <ifile> Field description file\n"+
132  "-data <ifile> Datafile, one vector per line\n"+
133  "-stop <int> {50} Minimum number of examples for leaf nodes\n"+
134  "-test <ifile> Datafile to test tree on\n"+
135  "-frs <float> {10} Float range split, number of partitions to\n"+
136  " split a float feature range into\n"+
137  "-dlist Build a decision list (rather than tree)\n"+
138  "-dtree Build a decision tree (rather than list) default\n"+
139  "-output <ofile> \n"+
140  "-o <ofile> File to save output tree in\n"+
141  "-distmatrix <ifile>\n"+
142  " A distance matrix for clustering\n"+
143  "-track <ifile>\n"+
144  " track for vertex indices\n"+
145  "-track_start <int>\n"+
146  " start channel vertex indices\n"+
147  "-track_end <int>\n"+
148  " end (inclusive) channel for vertex indices\n"+
149  "-track_feats <string>\n"+
150  " Track features to use, comma separated list\n"+
151  " with feature numbers and/or ranges, 0 start\n"+
152  "-unittrack <ifile>\n"+
153  " track for unit start and length in vertex track\n"+
154  "-quiet No questions printed during building\n"+
155  "-verbose Lost of information printing during build\n"+
156  "-predictee <string>\n"+
157  " name of field to predict (default is first field)\n"+
158  "-ignore <string>\n"+
159  " Filename or bracket list of fields to ignore\n"+
160  "-count_field <string>\n"+
161  " Name of field containing count weight for samples\n"+
162  "-stepwise Incrementally find best features\n"+
163  "-swlimit <float> {0.0}\n"+
164  " Percentage necessary improvement for stepwise,\n"+
165  " may be negative.\n"+
166  "-swopt <string> Parameter to optimize for stepwise, for \n"+
167  " classification options are correct or entropy\n"+
168  " for regression options are rmse or correlation\n"+
169  " correct and correlation are the defaults\n"+
170  "-balance <float> For derived stop size, if dataset at node, divided\n"+
171  " by balance is greater than stop it is used as stop\n"+
172  " if balance is 0 (default) always use stop as is.\n"+
173  "-vertex_output <string> Output <mean> or <best> of cluster\n"+
174  "-held_out <int> Percent to hold out for pruning\n"+
175  "-heap <int> {210000}\n"+
176  " Set size of Lisp heap, should not normally need\n"+
177  " to be changed from its default, only with *very*\n"+
178  " large description files (> 1M)\n"+
179  "-noprune No (same class) pruning required\n",
180  files, al);
181 
182  if (al.present("-held_out"))
183  wgn_held_out = al.ival("-held_out");
184  if (al.present("-balance"))
185  wgn_balance = al.fval("-balance");
186  if ((!al.present("-desc")) || ((!al.present("-data"))))
187  {
188  cerr << argv[0] << ": missing description and/or datafile" << endl;
189  cerr << "use -h for description of arguments" << endl;
190  }
191 
192  if (al.present("-quiet"))
193  wgn_quiet = TRUE;
194  if (al.present("-verbose"))
195  wgn_verbose = TRUE;
196 
197  if (al.present("-stop"))
198  wgn_min_cluster_size = atoi(al.val("-stop"));
199  if (al.present("-noprune"))
200  wgn_prune = FALSE;
201  if (al.present("-predictee"))
202  wgn_predictee_name = al.val("-predictee");
203  if (al.present("-count_field"))
204  wgn_count_field_name = al.val("-count_field");
205  if (al.present("-swlimit"))
206  stepwise_limit = al.fval("-swlimit");
207  if (al.present("-frs")) // number of partitions to try in floats
208  wgn_float_range_split = atof(al.val("-frs"));
209  if (al.present("-swopt"))
210  wgn_opt_param = al.val("-swopt");
211  if (al.present("-vertex_output"))
212  wgn_vertex_output = al.val("-vertex_output");
213  if (al.present("-output") || al.present("-o"))
214  {
215  if (al.present("-o"))
216  wgn_oname = al.val("-o");
217  else
218  wgn_oname = al.val("-output");
219  wgn_coutput = new ofstream(wgn_oname);
220  if (!(*wgn_coutput))
221  {
222  cerr << "Wagon: can't open file \"" << wgn_oname <<
223  "\" for output " << endl;
224  exit(-1);
225  }
226  }
227  else
228  wgn_coutput = &cout;
229  if (al.present("-distmatrix"))
230  {
231  if (wgn_DistMatrix.load(al.val("-distmatrix")) != 0)
232  {
233  cerr << "Wagon: failed to load Distance Matrix from \"" <<
234  al.val("-distmatrix") << "\"\n" << endl;
235  exit(-1);
236  }
237  }
238  if (al.present("-dlist"))
239  wagon_type = wn_decision_list;
240 
241  WNode *tree;
242  float score;
243  LISP ignores = NIL;
244 
245  siod_init(al.ival("-heap"));
246 
247  if (al.present("-ignore"))
248  {
249  EST_String ig = al.val("-ignore");
250  if (ig[0] == '(')
251  ignores = read_from_string(ig);
252  else
253  ignores = vload(ig,1);
254  }
255  // Load in the data
256  wgn_load_datadescription(al.val("-desc"),ignores);
257  wgn_load_dataset(wgn_dataset,al.val("-data"));
258  if (al.present("-distmatrix") &&
259  (wgn_DistMatrix.num_rows() < wgn_dataset.length()))
260  {
261  cerr << "wagon: distance matrix is smaller than number of training elements\n";
262  exit(-1);
263  }
264  else if (al.present("-track"))
265  {
266  wgn_VertexTrack.load(al.val("-track"));
267  wgn_VertexFeats.resize(1,wgn_VertexTrack.num_channels());
268  for (i=0; i<wgn_VertexFeats.num_channels(); i++)
269  wgn_VertexFeats.a(0,i) = 1.0;
270  }
271 
272  if (al.present("-track_start"))
273  {
274  feats_start = al.ival("-track_start");
275  if ((feats_start < 0) ||
276  (feats_start > wgn_VertexTrack.num_channels()))
277  {
278  printf("wagon: track_start invalid: %d out of %d channels\n",
279  feats_start,
280  wgn_VertexTrack.num_channels());
281  exit(-1);
282  }
283  for (i=0; i<feats_start; i++)
284  wgn_VertexFeats.a(0,i) = 0.0; /* don't do feats up to start */
285 
286  }
287 
288  if (al.present("-track_end"))
289  {
290  feats_end = al.ival("-track_end");
291  if ((feats_end < feats_start) ||
292  (feats_end > wgn_VertexTrack.num_channels()))
293  {
294  printf("wagon: track_end invalid: %d between start %d out of %d channels\n",
295  feats_end,
296  feats_start,
297  wgn_VertexTrack.num_channels());
298  exit(-1);
299  }
300  for (i=feats_end+1; i<wgn_VertexTrack.num_channels(); i++)
301  wgn_VertexFeats.a(0,i) = 0.0; /* don't do feats after end */
302  }
303  if (al.present("-track_feats"))
304  { /* overrides start and end numbers */
305  EST_String wagon_track_features = al.val("-track_feats");
306  set_Vertex_Feats(wgn_VertexFeats,wagon_track_features);
307  }
308 
309  // printf("Track feats\n");
310  // for (i=0; i<wgn_VertexTrack.num_channels(); i++)
311  // if (wgn_VertexFeats.a(0,i) > 0.0)
312  // printf("%d ",i);
313  // printf("\n");
314 
315  if (al.present("-unittrack"))
316  { /* contains two features, a start and length. start indexes */
317  /* into VertexTrack to the first vector in the segment */
318  wgn_UnitTrack.load(al.val("-unittrack"));
319  }
320 
321  if (al.present("-test"))
322  wgn_load_dataset(wgn_test_dataset,al.val("-test"));
323 
324  // Build and test the model
325  if (al.present("-stepwise"))
326  tree = wagon_stepwise(stepwise_limit);
327  else if (wagon_type == wn_decision_tree)
328  tree = wgn_build_tree(score); // default operation
329  else if (wagon_type == wn_decision_list)
330  // dlist is printed with build_dlist rather than returned
331  tree = wgn_build_dlist(score,wgn_coutput);
332  else
333  {
334  cerr << "Wagon: unknown operation, not tree or list" << endl;
335  exit(-1);
336  }
337 
338  if (tree != 0)
339  {
340  *wgn_coutput << *tree;
341  summary_results(*tree,wgn_coutput);
342  }
343 
344  if (wgn_coutput != &cout)
345  delete wgn_coutput;
346  return 0;
347 }
348 
void set_WhiteSpaceChars(const EST_String &ws)
set which characters are to be treated as whitespace
Definition: EST_Token.h:342
EST_TokenStream & get(EST_Token &t)
get next token in stream
Definition: EST_Token.cc:486
float & a(int i, int c=0)
Definition: EST_Track.cc:1022
int ival(const EST_String &rkey, int m=1) const
Definition: EST_Option.cc:76
int num_channels() const
return number of channels in track
Definition: EST_Track.h:656
void set_SingleCharSymbols(const EST_String &sc)
set which characters are to be treated as single character symbols
Definition: EST_Token.h:345
float fval(const EST_String &rkey, int m=1) const
Definition: EST_Option.cc:98
void set_PrePunctuationSymbols(const EST_String &ps)
set which characters are to be treated as (pre) punctuation
Definition: EST_Token.h:351
int open_string(const EST_String &newbuffer)
open a EST_TokenStream for string rather than a file
Definition: EST_Token.cc:251
void set_PunctuationSymbols(const EST_String &ps)
set which characters are to be treated as (post) punctuation
Definition: EST_Token.h:348
int eof()
end of file
Definition: EST_Token.h:363
EST_read_status load(const EST_String name, float ishift=0.0, float startt=0.0)
Definition: EST_Track.cc:1309
void resize(int num_frames, int num_channels, bool preserve=1)
Definition: EST_Track.cc:211
EST_read_status load(const EST_String &filename)
Load from file (ascii or binary as defined in file)
Definition: EST_FMatrix.cc:513
const int present(const K &rkey) const
Returns true if key is present.
Definition: EST_TKVL.cc:222
const V & val(const K &rkey, bool m=0) const
return value according to key (const)
Definition: EST_TKVL.cc:145
int filepos(void) const
current file position in EST_TokenStream
Definition: EST_Token.h:368
int num_rows() const
return number of rows
Definition: EST_TMatrix.h:178