Edinburgh Speech Tools  2.1-release
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
wagon_aux.cc
1 /*************************************************************************/
2 /* */
3 /* Centre for Speech Technology Research */
4 /* University of Edinburgh, UK */
5 /* Copyright (c) 1996,1997 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : May 1996 */
35 /*-----------------------------------------------------------------------*/
36 /* */
37 /* Various method functions */
38 /*=======================================================================*/
39 
40 #include <cstdlib>
41 #include <iostream>
42 #include <cstring>
43 #include "EST_unix.h"
44 #include "EST_cutils.h"
45 #include "EST_Token.h"
46 #include "EST_Wagon.h"
47 #include "EST_math.h"
48 
49 
50 EST_Val WNode::predict(const WVector &d)
51 {
52  if (leaf())
53  return impurity.value();
54  else if (question.ask(d))
55  return left->predict(d);
56  else
57  return right->predict(d);
58 }
59 
60 WNode *WNode::predict_node(const WVector &d)
61 {
62  if (leaf())
63  return this;
64  else if (question.ask(d))
65  return left->predict_node(d);
66  else
67  return right->predict_node(d);
68 }
69 
70 int WNode::pure(void)
71 {
72  // A node is pure if it has no sub-nodes or its not of type class
73 
74  if ((left == 0) && (right == 0))
75  return TRUE;
76  else if (get_impurity().type() != wnim_class)
77  return TRUE;
78  else
79  return FALSE;
80 }
81 
82 void WNode::prune(void)
83 {
84  // Check all sub-nodes and if they are all of the same class
85  // delete their sub nodes. Returns pureness of this node
86 
87  if (pure() == FALSE)
88  {
89  // Ok lets try and make it pure
90  if (left != 0) left->prune();
91  if (right != 0) right->prune();
92 
93  // Have to check purity as well as values to ensure left and right
94  // don't further split
95  if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
96  (left->get_impurity().value() == right->get_impurity().value()))
97  {
98  delete left; left = 0;
99  delete right; right = 0;
100  }
101  }
102 
103 }
104 
105 void WNode::held_out_prune()
106 {
107  // prune tree with held out data
108  // Check if node's questions differentiates for the held out data
109  // if not, prune all sub_nodes
110 
111  // Rescore with prune data
112  set_impurity(WImpurity(get_data())); // for this new data
113 
114  if (left != 0)
115  {
116  wgn_score_question(question,get_data());
117  if (question.get_score() < get_impurity().measure())
118  { // its worth goint ot the next level
119  wgn_find_split(question,get_data(),
120  left->get_data(),
121  right->get_data());
122  left->held_out_prune();
123  right->held_out_prune();
124  }
125  else
126  { // not worth the split so prune both sub_nodes
127  delete left; left = 0;
128  delete right; right = 0;
129  }
130  }
131 }
132 
133 void WNode::print_out(ostream &s, int margin)
134 {
135  int i;
136 
137  s << endl;
138  for (i=0;i<margin;i++) s << " ";
139  s << "(";
140  if (left==0) // base case
141  s << impurity;
142  else
143  {
144  s << question;
145  left->print_out(s,margin+1);
146  right->print_out(s,margin+1);
147  }
148  s << ")";
149 }
150 
151 ostream & operator <<(ostream &s, WNode &n)
152 {
153  // Output this node and its sub-node
154 
155  n.print_out(s,0);
156  s << endl;
157  return s;
158 }
159 
160 void WDataSet::ignore_non_numbers()
161 {
162  /* For ols we want to ignore anything that is categorial */
163  int i;
164 
165  for (i=0; i<dlength; i++)
166  {
167  if ((p_type[i] == wndt_binary) ||
168  (p_type[i] == wndt_float))
169  continue;
170  else
171  {
172  p_ignore[i] = TRUE;
173  }
174  }
175 
176  return;
177 }
178 
179 void WDataSet::load_description(const EST_String &fname, LISP ignores)
180 {
181  // Initialise a dataset with sizes and types
182  EST_String tname;
183  int i;
184  LISP description,d;
185 
186  description = car(vload(fname,1));
187  dlength = siod_llength(description);
188 
189  p_type.resize(dlength);
190  p_ignore.resize(dlength);
191  p_name.resize(dlength);
192 
193  if (wgn_predictee_name == "")
194  wgn_predictee = 0; // default predictee is first field
195  else
196  wgn_predictee = -1;
197 
198  for (i=0,d=description; d != NIL; d=cdr(d),i++)
199  {
200  p_name[i] = get_c_string(car(car(d)));
201  tname = get_c_string(car(cdr(car(d))));
202  p_ignore[i] = FALSE;
203  if ((wgn_predictee_name != "") && (wgn_predictee_name == p_name[i]))
204  wgn_predictee = i;
205  if ((wgn_count_field_name != "") &&
206  (wgn_count_field_name == p_name[i]))
207  wgn_count_field = i;
208  if ((tname == "count") || (i == wgn_count_field))
209  {
210  // The count must be ignored, repeat it if you want it too
211  p_type[i] = wndt_ignore; // the count must be ignored
212  p_ignore[i] = TRUE;
213  wgn_count_field = i;
214  }
215  else if ((tname == "ignore") || (siod_member_str(p_name[i],ignores)))
216  {
217  p_type[i] = wndt_ignore; // user specified ignore
218  p_ignore[i] = TRUE;
219  if (i == wgn_predictee)
220  wagon_error(EST_String("predictee \"")+p_name[i]+
221  "\" can't be ignored \n");
222  }
223  else if (siod_llength(car(d)) > 2)
224  {
225  LISP rest = cdr(car(d));
226  EST_StrList sl;
227  siod_list_to_strlist(rest,sl);
228  p_type[i] = wgn_discretes.def(sl);
229  if (streq(get_c_string(car(rest)),"_other_"))
230  wgn_discretes[p_type[i]].def_val("_other_");
231  }
232  else if (tname == "binary")
233  p_type[i] = wndt_binary;
234  else if (tname == "cluster")
235  p_type[i] = wndt_cluster;
236  else if (tname == "vector")
237  p_type[i] = wndt_vector;
238  else if (tname == "trajectory")
239  p_type[i] = wndt_trajectory;
240  else if (tname == "matrix")
241  p_type[i] = wndt_matrix;
242  else if (tname == "float")
243  p_type[i] = wndt_float;
244  else
245  {
246  wagon_error(EST_String("Unknown type \"")+tname+
247  "\" for field number "+itoString(i)+
248  "/"+p_name[i]+" in description file \""+fname+"\"");
249  }
250  }
251 
252  if (wgn_predictee == -1)
253  {
254  wagon_error(EST_String("predictee field \"")+wgn_predictee_name+
255  "\" not found in description ");
256  }
257 }
258 
259 const int WQuestion::ask(const WVector &w) const
260 {
261  // Ask this question of the given vector
262  switch (op)
263  {
264  case wnop_equal: // for numbers
265  if (w.get_flt_val(feature_pos) == operand1.Float())
266  return TRUE;
267  else
268  return FALSE;
269  case wnop_binary: // for numbers
270  if (w.get_int_val(feature_pos) == 1)
271  return TRUE;
272  else
273  return FALSE;
274  case wnop_greaterthan:
275  if (w.get_flt_val(feature_pos) > operand1.Float())
276  return TRUE;
277  else
278  return FALSE;
279  case wnop_lessthan:
280  if (w.get_flt_val(feature_pos) < operand1.Float())
281  return TRUE;
282  else
283  return FALSE;
284  case wnop_is: // for classes
285  if (w.get_int_val(feature_pos) == operand1.Int())
286  return TRUE;
287  else
288  return FALSE;
289  case wnop_in: // for subsets -- note operand is list of ints
290  if (ilist_member(operandl,w.get_int_val(feature_pos)))
291  return TRUE;
292  else
293  return FALSE;
294  default:
295  wagon_error("Unknown test operator");
296  }
297 
298  return FALSE;
299 }
300 
301 ostream& operator<<(ostream& s, const WQuestion &q)
302 {
303  EST_String name;
304  static EST_Regex needquotes(".*[()'\";., \t\n\r].*");
305 
306  s << "(" << wgn_dataset.feat_name(q.get_fp());
307  switch (q.get_op())
308  {
309  case wnop_equal:
310  s << " = " << q.get_operand1().string();
311  break;
312  case wnop_binary:
313  break;
314  case wnop_greaterthan:
315  s << " > " << q.get_operand1().Float();
316  break;
317  case wnop_lessthan:
318  s << " < " << q.get_operand1().Float();
319  break;
320  case wnop_is:
321  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
322  name(q.get_operand1().Int());
323  s << " is ";
324  if (name.matches(needquotes))
325  s << quote_string(name,"\"","\\",1);
326  else
327  s << name;
328  break;
329  case wnop_matches:
330  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
331  name(q.get_operand1().Int());
332  s << " matches " << quote_string(name,"\"","\\",1);
333  break;
334  case wnop_in:
335  s << " in (";
336  for (int l=0; l < q.get_operandl().length(); l++)
337  {
338  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
339  name(q.get_operandl().nth(l));
340  if (name.matches(needquotes))
341  s << quote_string(name,"\"","\\",1);
342  else
343  s << name;
344  s << " ";
345  }
346  s << ")";
347  break;
348  // SunCC wont let me add this
349 // default:
350 // s << " unknown operation ";
351  }
352  s << ")";
353 
354  return s;
355 }
356 
357 EST_Val WImpurity::value(void)
358 {
359  // Returns the recommended value for this
360  EST_String s;
361  double prob;
362 
363  if (t==wnim_unset)
364  {
365  cerr << "WImpurity: no value currently set\n";
366  return EST_Val(0.0);
367  }
368  else if (t==wnim_class)
369  return EST_Val(p.most_probable(&prob));
370  else if (t==wnim_cluster)
371  return EST_Val(a.mean());
372  else if (t==wnim_vector)
373  return EST_Val(a.mean()); /* wnim_vector */
374  else if (t==wnim_trajectory)
375  return EST_Val(a.mean()); /* NOT YET WRITTEN */
376  else
377  return EST_Val(a.mean());
378 }
379 
380 double WImpurity::samples(void)
381 {
382  if (t==wnim_float)
383  return a.samples();
384  else if (t==wnim_class)
385  return (int)p.samples();
386  else if (t==wnim_cluster)
387  return members.length();
388  else if (t==wnim_vector)
389  return members.length();
390  else if (t==wnim_trajectory)
391  return members.length();
392  else
393  return 0;
394 }
395 
396 WImpurity::WImpurity(const WVectorVector &ds)
397 {
398  int i;
399 
400  t=wnim_unset;
401  a.reset(); trajectory=0; l=0; width=0;
402  for (i=0; i < ds.n(); i++)
403  {
404  if (wgn_count_field == -1)
405  cumulate((*(ds(i)))[wgn_predictee],1);
406  else
407  cumulate((*(ds(i)))[wgn_predictee],
408  (*(ds(i)))[wgn_count_field]);
409  }
410 }
411 
412 float WImpurity::measure(void)
413 {
414  if (t == wnim_float)
415  return a.variance()*a.samples();
416  else if (t == wnim_vector)
417  return vector_impurity();
418  else if (t == wnim_trajectory)
419  return trajectory_impurity();
420  else if (t == wnim_matrix)
421  return a.variance()*a.samples();
422  else if (t == wnim_class)
423  return p.entropy()*p.samples();
424  else if (t == wnim_cluster)
425  return cluster_impurity();
426  else
427  {
428  cerr << "WImpurity: can't measure unset object" << endl;
429  return 0.0;
430  }
431 }
432 
433 float WImpurity::vector_impurity()
434 {
435  // Find the mean/stddev for all values in all vectors
436  // sum the variances and multiply them by the number of members
437  EST_Litem *pp;
438  int i,j;
439  EST_SuffStats b;
440  double count = 1;
441 
442  a.reset();
443 
444 #if 1
445  /* simple distance */
446  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
447  {
448  if (wgn_VertexFeats.a(0,j) > 0.0)
449  {
450  b.reset();
451  for (pp=members.head(); pp != 0; pp=pp->next())
452  {
453  i = members.item(pp);
454  b += wgn_VertexTrack.a(i,j);
455  }
456  a += b.stddev();
457  count = b.samples();
458  }
459  }
460 #endif
461 
462 #if 0
463  /* full covariance */
464  /* worse in listening experiments */
465  EST_SuffStats **cs;
466  int mmm;
467  cs = new EST_SuffStats *[wgn_VertexTrack.num_channels()+1];
468  for (j=0; j<=wgn_VertexTrack.num_channels(); j++)
469  cs[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
470  /* Find means for diagonal */
471  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
472  {
473  if (wgn_VertexFeats.a(0,j) > 0.0)
474  {
475  for (pp=members.head(); pp != 0; pp=pp->next())
476  cs[j][j] += wgn_VertexTrack.a(members.item(pp),j);
477  }
478  }
479  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
480  {
481  for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
482  if (wgn_VertexFeats.a(0,j) > 0.0)
483  {
484  for (pp=members.head(); pp != 0; pp=pp->next())
485  {
486  mmm = members.item(pp);
487  cs[i][j] += (wgn_VertexTrack.a(mmm,i)-cs[j][j].mean())*
488  (wgn_VertexTrack.a(mmm,j)-cs[j][j].mean());
489  }
490  }
491  }
492  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
493  {
494  for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
495  if (wgn_VertexFeats.a(0,j) > 0.0)
496  a += cs[i][j].stddev();
497  }
498  count = cs[0][0].samples();
499 #endif
500 
501 #if 0
502  // look at mean euclidean distance between vectors
503  EST_Litem *qq;
504  int x,y;
505  double d,q;
506  count = 0;
507  for (pp=members.head(); pp != 0; pp=pp->next())
508  {
509  x = members.item(pp);
510  count++;
511  for (qq=pp->next(); qq != 0; qq=qq->next())
512  {
513  y = members.item(qq);
514  for (q=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
515  if (wgn_VertexFeats.a(0,j) > 0.0)
516  {
517  d = wgn_VertexTrack(x,j)-wgn_VertexTrack(y,j);
518  q += d*d;
519  }
520  a += sqrt(q);
521  }
522 
523  }
524 #endif
525 
526  // This is sum of stddev*samples
527  return a.mean() * count;
528 }
529 
530 WImpurity::~WImpurity()
531 {
532  int j;
533 
534  if (trajectory != 0)
535  {
536  for (j=0; j<l; j++)
537  delete [] trajectory[j];
538  delete [] trajectory;
539  trajectory = 0;
540  l = 0;
541  }
542 }
543 
544 
// trajectory_impurity: impurity score for a cluster of variable-length
// units (used by measure() when t == wnim_trajectory).
//
// Each member index i names a row of wgn_UnitTrack: column 0 is the unit's
// start row in wgn_VertexTrack and column 1 its length in rows.  Only
// channels j with wgn_VertexFeats.a(0,j) > 0.0 contribute.
//
// The result is computed once.  It is stored in `score`, and the per-point
// statistics are kept in `trajectory` (l arrays of EST_SuffStats, freed by
// the destructor).  Later calls return the cached score.
//
// Two models are used:
//  - plain: no member contains a row whose channel 0 is -1.0.  Every unit
//    is resampled (by truncating the interpolated index) to
//    l = max(7, (int)mean length) points.  score is the mean stddev over
//    all points and enabled channels, times the number of members.
//  - OLA: at least one member contains such a -1.0 row.  Units are split
//    there into a first part resampled to l1 points and a second part
//    resampled to l2 points (each at least 10, from the mean part lengths).
//    Slot l1 receives a -1 marker and a later slot a -2 marker;
//    l = l1 + l2 + 2.  The stddevs are weighted by a triangular window
//    (rising over the first part, falling over the second) before
//    averaging, and the marker slots are skipped.
545 float WImpurity::trajectory_impurity()
546 {
547  // Find the mean length of all the units in the cluster
548  // Create that number of points
549  // Interpolate each unit to that number of points
550  // collect means and standard deviations for each point
551  // impurity is sum of the variance for each point and each coef
552  // multiplied by the number of units.
553  EST_Litem *pp;
554  int i, j;
555  int s, ti, ni, q;
556  int s1l, s2l;
557  double n, m, m1, m2, w;
558  EST_SuffStats lss, stdss;
559  EST_SuffStats l1ss, l2ss;
560  int l1, l2;
561  int ola=0;
562 
563  if (trajectory != 0)
564  { /* already done this */
565  return score;
566  }
567 
568  lss.reset();
569  l = 0;
    // First pass over the members.  Locate each unit's -1.0 row (channel 0),
    // if any, and gather length statistics for both models: l1ss = rows
    // before the -1 row, l2ss = rows after it less one (matching s2l in the
    // OLA branch), lss = whole unit length.
570  for (pp=members.head(); pp != 0; pp=pp->next())
571  {
572  i = members.item(pp);
573  for (q=0; q<wgn_UnitTrack.a(i,1); q++)
574  {
575  ni = (int)wgn_UnitTrack.a(i,0)+q;
576  if (wgn_VertexTrack.a(ni,0) == -1.0)
577  {
578  l1ss += q;
579  ola = 1;
580  break;
581  }
582  }
583  if (q==wgn_UnitTrack.a(i,1))
584  { /* can't find -1 center point, so put all in l2 */
585  l1ss += 0;
586  l2ss += q;
587  }
588  else
589  l2ss += wgn_UnitTrack.a(i,1) - (q+1) - 1;
590  lss += wgn_UnitTrack.a(i,1); /* length of each unit in the cluster */
591  if (wgn_UnitTrack.a(i,1) > l)
592  l = (int)wgn_UnitTrack.a(i,1);
593  }
594 
    // NOTE: l (maximum unit length) computed above is overwritten in both
    // branches below.
595  if (ola==0) /* no -1's so its not an ola type cluster */
596  {
597  l = ((int)lss.mean() < 7) ? 7 : (int)lss.mean();
598 
599  /* a list of SuffStats on for each point in the trajectory */
600  trajectory = new EST_SuffStats *[l];
601  width = wgn_VertexTrack.num_channels()+1;
602  for (j=0; j<l; j++)
603  trajectory[j] = new EST_SuffStats[width];
604 
    // Resample every unit to l points: point ti reads row start+(int)(ti*m),
    // where m = unit length / l.
605  for (pp=members.head(); pp != 0; pp=pp->next())
606  { /* for each unit */
607  i = members.item(pp);
608  m = (float)wgn_UnitTrack.a(i,1)/(float)l; /* find interpolation */
609  s = (int)wgn_UnitTrack.a(i,0); /* start point */
610  for (ti=0,n=0.0; ti<l; ti++,n+=m)
611  {
612  ni = (int)n; // hmm floor or nint ??
613  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
614  {
615  if (wgn_VertexFeats.a(0,j) > 0.0)
616  trajectory[ti][j] += wgn_VertexTrack.a(s+ni,j);
617  }
618  }
619  }
620 
621  /* find sum of sum of stddev for all coefs of all traj points */
622  stdss.reset();
623  for (ti=0; ti<l; ti++)
624  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
625  {
626  if (wgn_VertexFeats.a(0,j) > 0.0)
627  stdss += trajectory[ti][j].stddev();
628  }
629 
630  // This is sum of all stddev * samples
631  score = stdss.mean() * members.length();
632  }
633  else
634  { /* OLA model */
635  l1 = (l1ss.mean() < 10.0) ? 10 : (int)l1ss.mean();
636  l2 = (l2ss.mean() < 10.0) ? 10 : (int)l2ss.mean();
    // l1 points for the first part, one -1 marker slot, l2 points for the
    // second part, and one final slot.
637  l = l1 + l2 + 1 + 1;
638 
639  /* a list of SuffStats on for each point in the trajectory */
    // NOTE(review): unlike the plain branch, `width` is not updated here.
640  trajectory = new EST_SuffStats *[l];
641  for (j=0; j<l; j++)
642  trajectory[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
643 
644  for (pp=members.head(); pp != 0; pp=pp->next())
645  { /* for each unit */
646  i = members.item(pp);
647  s1l = 0;
648  s = (int)wgn_UnitTrack.a(i,0); /* start point */
    // s1l = offset of this unit's -1.0 row (0 if the unit has none).
649  for (q=0; q<wgn_UnitTrack.a(i,1); q++)
650  if (wgn_VertexTrack.a(s+q,0) == -1.0)
651  {
652  s1l = q; /* printf("awb q is -1 at %d\n",q); */
653  break;
654  }
655  s2l = (int)wgn_UnitTrack.a(i,1) - (s1l + 2);
656  m1 = (float)(s1l)/(float)l1; /* find interpolation step */
657  m2 = (float)(s2l)/(float)l2; /* find interpolation step */
658  /* First half */
    // Resampled index is clamped to the last row before the -1 row.
659  for (ti=0,n=0.0; s1l > 0 && ti<l1; ti++,n+=m1)
660  {
661  ni = s + (((int)n < s1l) ? (int)n : s1l - 1);
662  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
663  if (wgn_VertexFeats.a(0,j) > 0.0)
664  trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
665  }
666  ti = l1; /* do it explicitly in case s1l < 1 */
667  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
668  if (wgn_VertexFeats.a(0,j) > 0.0)
669  trajectory[ti][j] += -1;
670  /* Second half */
    // The second part starts on the row after the -1 row and is written to
    // slots l1+1 .. l-2.
671  s += s1l+1;
672  for (ti++,n=0.0; s2l > 0 && ti<l-1; ti++,n+=m2)
673  {
674  ni = s + (((int)n < s2l) ? (int)n : s2l - 1);
675  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
676  if (wgn_VertexFeats.a(0,j) > 0.0)
677  trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
678  }
    // When s2l > 0 the loop above leaves ti == l-1, so the -2 marker goes
    // in the last slot.
    // NOTE(review): if s2l <= 0 the loop does not run and ti is l1+1 here
    // instead -- confirm this is intended.
679  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
680  if (wgn_VertexFeats.a(0,j) > 0.0)
681  trajectory[ti][j] += -2;
682  }
683 
684  /* find sum of sum of stddev for all coefs of all traj points */
685  /* windowing the sums with a triangular weight window */
    // First part: slots 0..l1-1 with weights 0, 1/l1, ..., (l1-1)/l1.
686  stdss.reset();
687  m = 1.0/(float)l1;
688  for (w=0.0,ti=0; ti<l1; ti++,w+=m)
689  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
690  if (wgn_VertexFeats.a(0,j) > 0.0)
691  stdss += trajectory[ti][j].stddev() * w;
    // Second part: skip the -1 slot (l1); slots l1+1..l-2 with weights
    // falling from 1 in steps of 1/l2.  The last slot (l-1) is not scored.
692  m = 1.0/(float)l2;
693  for (w=1.0,ti++; ti<l-1; ti++,w-=m)
694  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
695  if (wgn_VertexFeats.a(0,j) > 0.0)
696  stdss += trajectory[ti][j].stddev() * w;
697 
698  // This is sum of all stddev * samples
699  score = stdss.mean() * members.length();
700  }
701  return score;
702 }
703 
704 float WImpurity::cluster_impurity()
705 {
706  // Find the mean distance between all members of the dataset
707  // Uses the global DistMatrix for distances between members of
708  // the cluster set. Distances are assumed to be symmetric thus only
709  // the bottom half of the distance matrix is filled
710  EST_Litem *pp, *q;
711  int i,j;
712  double dist;
713 
714  a.reset();
715  for (pp=members.head(); pp != 0; pp=pp->next())
716  {
717  i = members.item(pp);
718  for (q=pp->next(); q != 0; q=q->next())
719  {
720  j = members.item(q);
721  dist = (j < i ? wgn_DistMatrix.a_no_check(i,j) :
722  wgn_DistMatrix.a_no_check(j,i));
723  a+=dist; // cumulate for whole cluster
724  }
725  }
726 
727  // This is sum distance between cross product of members
728 // return a.sum();
729  if (a.samples() > 1)
730  return a.stddev() * a.samples();
731  else
732  return 0.0;
733 }
734 
735 float WImpurity::cluster_distance(int i)
736 {
737  // Distance this unit is from all others in this cluster
738  // in absolute standard deviations from the the mean.
739  float dist = cluster_member_mean(i);
740  float mdist = dist-a.mean();
741 
742  if (mdist == 0.0)
743  return 0.0;
744  else
745  return fabs((dist-a.mean())/a.stddev());
746 
747 }
748 
749 int WImpurity::in_cluster(int i)
750 {
751  // Would this be a member of this cluster?. Returns 1 if
752  // its distance is less than at least one other
753  float dist = cluster_member_mean(i);
754  EST_Litem *pp;
755 
756  for (pp=members.head(); pp != 0; pp=pp->next())
757  {
758  if (dist < cluster_member_mean(members.item(pp)))
759  return 1;
760  }
761  return 0;
762 }
763 
764 float WImpurity::cluster_ranking(int i)
765 {
766  // Position in ranking closest to centre
767  float dist = cluster_distance(i);
768  EST_Litem *pp;
769  int ranking = 1;
770 
771  for (pp=members.head(); pp != 0; pp=pp->next())
772  {
773  if (dist >= cluster_distance(members.item(pp)))
774  ranking++;
775  }
776 
777  return ranking;
778 }
779 
780 float WImpurity::cluster_member_mean(int i)
781 {
782  // Returns the mean difference between this member and all others
783  // in cluster
784  EST_Litem *q;
785  int j,n;
786  double dist,sum;
787 
788  for (sum=0.0,n=0,q=members.head(); q != 0; q=q->next())
789  {
790  j = members.item(q);
791  if (i != j)
792  {
793  dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
794  sum += dist;
795  n++;
796  }
797  }
798 
799  return ( n == 0 ? 0.0 : sum/n );
800 }
801 
802 void WImpurity::cumulate(const float pv,double count)
803 {
804  // Cumulate data for impurity calculation
805 
806  if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
807  {
808  t = wnim_cluster;
809  members.append((int)pv);
810  }
811  else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
812  {
813  t = wnim_vector;
814  members.append((int)pv);
815  }
816  else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
817  {
818  t = wnim_trajectory;
819  members.append((int)pv);
820  }
821  else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
822  {
823  if (t == wnim_unset)
824  p.init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
825  t = wnim_class;
826  p.cumulate((int)pv,count);
827  }
828  else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
829  {
830  t = wnim_float;
831  a.cumulate((int)pv,count);
832  }
833  else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
834  {
835  t = wnim_float;
836  a.cumulate(pv,count);
837  }
838  else
839  {
840  wagon_error("WImpurity: cannot cumulate EST_Val type");
841  }
842 }
843 
// operator<<(ostream, WImpurity): write an impurity (the leaf content of a
// wagon tree) as an s-expression whose shape depends on imp.t:
//   wnim_float       (STDDEV MEAN)
//   wnim_vector      (((V SD) (V SD) ...) A-MEAN)   -- recomputes vector_impurity()
//   wnim_trajectory  (( POINT-LINE... ) A-MEAN)     -- runs trajectory_impurity()
//   wnim_cluster     (((MEMBER MEAN-DIST) ...) A-MEAN)
//   wnim_class       ((NAME PROB) ... MOST-PROBABLE)
//   otherwise        ([WImpurity unset])
// imp is taken non-const because the vector and trajectory branches update
// its cached statistics before printing.
844 ostream & operator <<(ostream &s, WImpurity &imp)
845 {
846  int j,i;
847  EST_SuffStats b;
848 
849  if (imp.t == wnim_float)
850  s << "(" << imp.a.stddev() << " " << imp.a.mean() << ")";
851  else if (imp.t == wnim_vector)
852  {
853  EST_Litem *p;
854  s << "((";
    // Refill imp.a so that imp.a.mean() printed below is the mean
    // per-channel stddev.
855  imp.vector_impurity();
856  if (wgn_vertex_output == "mean") //output means
857  {
    // One "(mean stddev)" entry per wgn_VertexTrack channel, over all members.
858  for (j=0; j<wgn_VertexTrack.num_channels(); j++)
859  {
860  b.reset();
861  for (p=imp.members.head(); p != 0; p=p->next())
862  {
863  b += wgn_VertexTrack.a(imp.members.item(p),j);
864  }
865  s << "(" << b.mean() << " " << b.stddev() << ")";
866  if (j+1<wgn_VertexTrack.num_channels())
867  s << " ";
868  }
869  }
870  else /* output best in the cluster */
871  {
872  /* print out vector closest to center, rather than average */
    // cs[j] collects channel j over all members, for enabled channels only.
    // NOTE(review): cs has wgn_VertexTrack.num_channels()+1 entries but is
    // indexed by wgn_VertexFeats channels -- assumes VertexFeats has no more
    // channels than that; confirm.
873  double best = WGN_HUGE_VAL;
874  double x,d;
875  int bestp = 0;
876  EST_SuffStats *cs;
877 
878  cs = new EST_SuffStats [wgn_VertexTrack.num_channels()+1];
879 
880  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
881  if (wgn_VertexFeats.a(0,j) > 0.0)
882  {
883  cs[j].reset();
884  for (p=imp.members.head(); p != 0; p=p->next())
885  {
886  cs[j] += wgn_VertexTrack.a(imp.members.item(p),j);
887  }
888  }
889 
    // Pick the member with the smallest squared distance to the channel
    // means (enabled channels only, not normalised by stddev).
890  for (p=imp.members.head(); p != 0; p=p->next())
891  {
892  for (x=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
893  if (wgn_VertexFeats.a(0,j) > 0.0)
894  {
895  d = (wgn_VertexTrack.a(imp.members.item(p),j)-cs[j].mean())
896  /* / cs[j].stddev() */ ; /* seems worse 061218 */
897  x += d*d;
898  }
899  if (x < best)
900  {
901  bestp = imp.members.item(p);
902  best = x;
903  }
904  }
    // Print the chosen member's values with the cluster's per-channel stddev.
    // Channels that collected no statistics print their stddev(), or "0"
    // when that value is not finite.
905  for (j=0; j<wgn_VertexTrack.num_channels(); j++)
906  {
907  s << "( ";
908  s << wgn_VertexTrack.a(bestp,j);
909  // s << " 0 "; // fake stddev
910  s << " ";
911  if (finite(cs[j].stddev()))
912  s << cs[j].stddev();
913  else
914  s << "0";
915  s << " ) ";
916  if (j+1<wgn_VertexTrack.num_channels())
917  s << " ";
918  }
919 
920  delete [] cs;
921  }
922  s << ") ";
923  s << imp.a.mean() << ")";
924  }
925  else if (imp.t == wnim_trajectory)
926  {
    // One output line per trajectory point, each listing "(mean stddev )"
    // for every wgn_VertexTrack channel.
927  s << "((";
928  imp.trajectory_impurity();
929  for (i=0; i<imp.l; i++)
930  {
931  s << "(";
932  for (j=0; j<wgn_VertexTrack.num_channels(); j++)
933  {
934  s << "(" << imp.trajectory[i][j].mean() << " "
935  << imp.trajectory[i][j].stddev() << " " << ")";
936  }
937  s << ")\n";
938  }
939  s << ") ";
940  // Mean of cross product of distances (cluster score)
    // NOTE(review): trajectory_impurity() and cumulate() in this file do not
    // add to imp.a for trajectories, so this may print the mean of an empty
    // accumulator -- confirm.
941  s << imp.a.mean() << ")";
942  }
943  else if (imp.t == wnim_cluster)
944  {
945  EST_Litem *p;
946  s << "((";
947  for (p=imp.members.head(); p != 0; p=p->next())
948  {
949  // Output cluster member and its mean distance to others
950  s << "(" << imp.members.item(p) << " " <<
951  imp.cluster_member_mean(imp.members.item(p)) << ")";
952  if (p->next() != 0)
953  s << " ";
954  }
955  s << ") ";
956  // Mean of cross product of distances (cluster score)
    // (only meaningful if imp.a was filled, e.g. by cluster_impurity())
957  s << imp.a.mean() << ")";
958  }
959  else if (imp.t == wnim_class)
960  {
961  EST_Litem *i;
962  EST_String name;
963  double prob;
964 
    // Every class with its probability, then the most probable class.
965  s << "(";
966  for (i=imp.p.item_start(); !imp.p.item_end(i); i=imp.p.item_next(i))
967  {
968  imp.p.item_prob(i,name,prob);
969  s << "(" << name << " " << prob << ") ";
970  }
971  s << imp.p.most_probable(&prob) << ")";
972  }
973  else
974  s << "([WImpurity unset])";
975 
976  return s;
977 }
978 
979 
980 
981 
INLINE const T & a_no_check(int row, int col) const
const access with no bounds check; use with care
Definition: EST_TMatrix.h:183
EST_Litem * item_next(EST_Litem *idx) const
Used for iterating through members of the distribution.
const EST_String & most_probable(double *prob=NULL) const
Return the most probable member of the distribution.
double stddev(void) const
standard deviation of currently cumulated values
float & a(int i, int c=0)
Definition: EST_Track.cc:1022
double samples(void) const
Total number of examples found.
const int Int(void) const
Definition: EST_Val.h:135
A Regular expression class to go with the CSTR EST_String class.
Definition: EST_Regex.h:56
bool init(const EST_StrList &vocab)
Initialise using given vocabulary.
int num_channels() const
return number of channels in track
Definition: EST_Track.h:656
double mean(void) const
mean of currently cumulated values
EST_String itoString(int n)
Make a EST_String object from an integer.
Definition: util_io.cc:140
const float Float(void) const
Definition: EST_Val.h:143
EST_Litem * item_start() const
Used for iterating through members of the distribution.
int item_end(EST_Litem *idx) const
Used for iterating through members of the distribution.
void resize(int n, int set=1)
Definition: EST_TVector.cc:196
void cumulate(const EST_String &s, double count=1)
Add this observation, may specify number of occurrences.
double entropy(void) const
void item_prob(EST_Litem *idx, EST_String &s, double &prob) const
During iteration returns name and probability given index.
T & nth(int n)
return the Nth value
Definition: EST_TList.h:147
double variance(void) const
variance of currently cumulated values
const EST_String & string(void) const
Definition: EST_Val.h:155
INLINE int n() const
number of items in vector.
Definition: EST_TVector.h:252
void reset(void)
reset internal values
void append(const T &item)
add item onto end of list
Definition: EST_TList.h:198
double samples(void)
number of samples in set
int matches(const char *e, int pos=0) const
Exactly match this string?
Definition: EST_String.cc:652
void resize(int n, int set=1)
resize vector
T & item(const EST_Litem *p)
Definition: EST_TList.h:141