54 static void load_vocab(
const EST_String &vfile);
58 static void load_wstream(
const EST_String &filename,
63 static void load_given(
const EST_String &filename,
64 const int ngram_order);
66 static double find_gram_prob(
EST_VTPath *p,
int *state);
69 static double find_extra_gram_prob(
EST_VTPath *p,
int *state,
int time);
73 static int is_a_special(
const EST_String &s,
int &val);
74 static int max_history=0;
77 static EST_String pstring = SENTENCE_START_MARKER;
78 static EST_String ppstring = SENTENCE_END_MARKER;
79 static float lm_scale = 1.0;
80 static float ob_scale = 1.0;
81 static float ob_scale2 = 1.0;
85 static float ob_beam=-1;
86 static int n_beam = -1;
88 static bool trace_on = FALSE;
91 static double ob_log_prob_floor = SAFE_LOG_ZERO;
92 static double ob_log_prob_floor2 = SAFE_LOG_ZERO;
93 static double lm_log_prob_floor = SAFE_LOG_ZERO;
95 int btest_debug = FALSE;
101 int using_given=FALSE;
104 int take_logs = FALSE;
108 int main(
int argc,
char **argv)
115 parse_command_line(argc, argv,
116 EST_String(
"[observations file] -o [output file]\n")+
117 "Summary: find the most likely path through a sequence of\n"+
118 " observations, constrained by a language model.\n"+
119 "-ngram <string> Grammar file, required\n"+
120 "-given <string> ngram left contexts, per frame\n"+
121 "-vocab <string> File with names of vocabulary, this\n"+
122 " must be same number as width of observations, required\n"+
123 "-ob_type <string> Observation type : likelihood .... and change doc\"probs\" or \"logs\" (default is \"logs\")\n"+
124 "\nFloor values and scaling (scaling is applied after floor value)\n"+
125 "-lm_floor <float> LM floor probability\n"+
126 "-lm_scale <float> LM scale factor factor (applied to log prob)\n"+
127 "-ob_floor <float> Observations floor probability\n"+
128 "-ob_scale <float> Observation scale factor (applied to prob or log prob, depending on -ob_type)\n\n"+
129 "-prev_tag <string>\n"+
130 " tag before sentence start\n"+
131 "-prev_prev_tag <string>\n"+
132 " all words before 'prev_tag'\n"+
133 "-last_tag <string>\n"+
134 " after sentence end\n"+
135 "-default_tags use default tags of "+SENTENCE_START_MARKER+
","
136 SENTENCE_END_MARKER+
" and "+SENTENCE_END_MARKER+
"\n"+
139 "-observes2 <string> second observations (overlays first, ob_type must be same)\n"+
140 "-ob_floor2 <float> \n"+
141 "-ob_scale2 <float> \n\n"+
142 "-ob_prune <float> observation pruning beam width (log) probability\n"+
143 "-n_prune <int> top-n pruning of observations\n"+
144 "-prune <float> pruning beam width (log) probability\n"+
145 "-trace show details of search as it proceeds\n",
150 if (files.length() != 1)
153 cerr <<
": you must give exactly one observations file on the command line";
155 cerr <<
"(use -observes2 for optional second observations)" << endl;
161 ngram.load(al.
val(
"-ngram"));
165 cerr << argv[0] <<
": no ngram specified" << endl;
171 cerr <<
"You must provide a vocabulary file !" << endl;
175 load_wstream(files.
first(),al.
val(
"-vocab"),wstream,observations);
178 load_wstream(al.
val(
"-observes2"),al.
val(
"-vocab"),wstream,observations2);
184 load_given(al.
val(
"-given"),ngram.order());
189 lm_scale = al.
fval(
"-lm_scale");
194 ob_scale = al.
fval(
"-ob_scale");
199 ob_scale2 = al.
fval(
"-ob_scale2");
204 pstring = al.
val(
"-prev_tag");
205 if (al.
present(
"-prev_prev_tag"))
206 ppstring = al.
val(
"-prev_prev_tag");
210 beam = al.
fval(
"-prune");
215 ob_beam = al.
fval(
"-ob_prune");
221 n_beam = al.
ival(
"-n_prune");
224 cerr <<
"WARNING : " << n_beam;
225 cerr <<
" is not a reasonable value for -n_prune !" << endl;
239 floor = al.
fval(
"-lm_floor");
242 cerr <<
"Error : LM floor probability is negative !" << endl;
247 cerr <<
"Error : LM floor probability > 1 " << endl;
250 lm_log_prob_floor = safe_log(floor);
256 floor = al.
fval(
"-ob_floor");
259 cerr <<
"Error : Observation floor probability is negative !" << endl;
264 cerr <<
"Error : Observation floor probability > 1 " << endl;
267 ob_log_prob_floor = safe_log(floor);
272 floor = al.
fval(
"-ob_floor2");
275 cerr <<
"Error : Observation2 floor probability is negative !" << endl;
280 cerr <<
"Error : Observation2 floor probability > 1 " << endl;
283 ob_log_prob_floor2 = safe_log(floor);
289 if(al.
val(
"-ob_type") ==
"logs")
291 else if(al.
val(
"-ob_type") ==
"probs")
295 cerr <<
"\"" << al.
val(
"-ob_type")
296 <<
"\" is not a valid ob_type : try \"logs\" or \"probs\"" << endl;
301 if(do_search(wstream))
302 print_results(wstream);
304 cerr <<
"No path could be found." << endl;
318 else if ((fd = fopen(out_file,
"wb")) == NULL)
320 cerr <<
"can't open \"" << out_file <<
"\" for output" << endl;
324 for (s=wstream.
head(); s != 0 ; s=s->next())
326 predict = s->f(
"best").
string();
327 pscore = s->f(
"best_score");
328 fprintf(fd,
"%s %f\n",(
const char *)predict,pscore);
341 states = ngram.num_states();
344 vc.initialise(&wstream);
346 if((beam > 0) || (ob_beam > 0))
347 vc.set_pruning_parameters(beam,ob_beam);
352 cerr <<
"Starting Viterbi search..." << endl;
357 return vc.result(
"best");
361 static void load_wstream(
const EST_String &filename,
373 if (obs.
load(filename,0.10) != 0)
375 cerr <<
"can't find observations file \"" << filename <<
"\"" << endl;
381 cerr <<
"Number in vocab (" << vocab.length() <<
382 ") not equal to observation's width (" <<
393 static void load_given(
const EST_String &filename,
394 const int ngram_order)
401 if (load_TList_of_StrVector(given,filename,ngram_order-1) != 0)
403 cerr <<
"can't load given file \"" << filename <<
"\"" << endl;
408 for (p = given.head(); p; p = p->next())
410 for(i=0;i<given(p).length();i++)
411 if( is_a_special( given(p)(i), j) && (-j > max_history))
418 static void load_vocab(
const EST_String &vfile)
423 if (ts.
open(vfile) == -1)
425 cerr <<
"can't find vocab file \"" << vfile <<
"\"" << endl;
440 item->set_name(word);
441 item->
set(
"pos",pos);
447 double prob=1.0,prob2=1.0;
455 observe = s->f(
"pos");
456 for (i=0,p=vocab.head(); i < observations.
num_channels(); i++,p=p->next())
460 prob = observations.
a(observe,i);
462 prob2 = observations2.
a(observe,i);
466 prob = safe_log10(prob);
467 if (prob < ob_log_prob_floor)
468 prob = ob_log_prob_floor;
472 prob2 = safe_log10(prob2);
473 if (prob2 < ob_log_prob_floor2)
474 prob2 = ob_log_prob_floor2;
479 if (prob < ob_log_prob_floor)
480 prob = ob_log_prob_floor;
481 if ((num_obs == 2) && (prob2 < ob_log_prob_floor2))
482 prob2 = ob_log_prob_floor2;
489 c->score = prob + prob2;
501 top_n_candidates(all_c);
524 prob = find_extra_gram_prob(np,&np->state,c->s->f(
"pos"));
526 prob = find_gram_prob(np,&np->state);
528 lprob = safe_log10(prob);
529 if (lprob < lm_log_prob_floor)
530 lprob = lm_log_prob_floor;
534 np->f.
set(
"lscore",(c->score+lprob));
536 np->score = (c->score+lprob);
538 np->score = (c->score+lprob) + p->score;
543 static double find_gram_prob(
EST_VTPath *p,
int *state)
547 double prob=0.0,nprob;
552 for (pp=p->from,i=ngram.order()-2; i >= 0; i--)
556 window[i] = pp->c->name.
string();
560 window[i] = ppstring;
567 window[ngram.order()-1] = p->c->name.
string();
572 prob = (double)pd.probability(p->c->name.
string());
574 for (i=0; i < ngram.order()-1; i++)
575 window[i] = window(i+1);
576 ngram.predict(window,&nprob,state);
582 static double find_extra_gram_prob(
EST_VTPath *p,
int *state,
int time)
586 double prob=0.0,nprob;
590 get_history(history,p);
592 fill_window(window,history,p,time);
605 prob = (double)pd.probability(p->c->name.
string());
610 for(i=history.length()-1;i>0;i--)
611 history[i] = history(i-1);
612 history[0] = p->c->name.
string();
615 fill_window(window,history,p,time+1);
616 ngram.predict(window,&nprob,state);
629 for (pp=p->from,i=0; i < history.
length(); i++)
634 history[i] = pp->c->name.
string();
638 history[i] = ppstring;
641 history[i] = pstring;
657 if( time >= given.length() )
667 window[ngram.order()-1] = p->c->name.
string();
673 for(i=0;i<ngram.order()-1;i++)
676 if( is_a_special( (*this_g)(i), j))
677 window[i] = history(-1-j);
679 window[i] = (*this_g)(i);
685 static int is_a_special(
const EST_String &s,
int &val)
716 for(i=0;i<n_beam;i++)
725 for(p=all_c;p!= NULL;q=p,p=p->next)
728 if(p->score > this_best->score)
735 if(this_best == NULL)
739 if(prev_to_best == NULL)
741 all_c = this_best->next;
744 prev_to_best->next = this_best->next;
746 this_best->next = top_c;
const T & first() const
return const reference to first item in list
EST_TokenStream & get(EST_Token &t)
get next token in stream
float & a(int i, int c=0)
double samples(void) const
Total number of example found.
int ival(const EST_String &rkey, int m=1) const
int num_channels() const
return number of channels in track
void close(void)
Close stream.
EST_String itoString(int n)
Make a EST_String object from an integer.
float fval(const EST_String &rkey, int m=1) const
void set(const EST_String &name, int ival)
int open(const EST_String &filename)
open a EST_TokenStream for a file.
INLINE int length() const
number of items in vector.
void set(const EST_String &name, int ival)
EST_read_status load(const EST_String name, float ishift=0.0, float startt=0.0)
T & nth(int n)
return the Nth value
const int present(const K &rkey) const
Returns true if key is present.
EST_Token & peek(void)
peek at next token
const EST_String & string(void) const
const V & val(const K &rkey, bool m=0) const
return value according to key (const)
void append(const T &item)
add item onto end of list
int contains(const char *s, int pos=-1) const
Does it contain this substring?
A class that offers a generalised Viterbi decoder.
int num_frames() const
return number of frames in track
EST_String after(int pos, int len=1) const
Part after pos+len.
EST_String before(int pos, int len=0) const
Part before position.