SHOGUN  v3.2.0
MAPInference.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Copyright (C) 2013 Shell Hu
9  */
10 
14 
15 using namespace shogun;
16 
18 {
19  SG_UNSTABLE("CMAPInference::CMAPInference()", "\n");
20 
21  init();
22 }
23 
25  : CSGObject()
26 {
27  init();
28  m_fg = fg;
29 
30  REQUIRE(fg != NULL, "%s::CMAPInference(): fg cannot be NULL!\n", get_name());
31 
32  switch(inference_method)
33  {
34  case TREE_MAX_PROD:
35  m_infer_impl = new CTreeMaxProduct(fg);
36  break;
37  case LOOPY_MAX_PROD:
38  SG_ERROR("%s::CMAPInference(): LoopyMaxProduct has not been implemented!\n",
39  get_name());
40  break;
41  case LP_RELAXATION:
42  SG_ERROR("%s::CMAPInference(): LPRelaxation has not been implemented!\n",
43  get_name());
44  break;
45  case TRWS_MAX_PROD:
46  SG_ERROR("%s::CMAPInference(): TRW-S has not been implemented!\n",
47  get_name());
48  break;
49  case ITER_COND_MODE:
50  SG_ERROR("%s::CMAPInference(): ICM has not been implemented!\n",
51  get_name());
52  break;
53  case NAIVE_MEAN_FIELD:
54  SG_ERROR("%s::CMAPInference(): NaiveMeanField has not been implemented!\n",
55  get_name());
56  break;
57  case STRUCT_MEAN_FIELD:
58  SG_ERROR("%s::CMAPInference(): StructMeanField has not been implemented!\n",
59  get_name());
60  break;
61  default:
62  SG_ERROR("%s::CMAPInference(): unsupported inference method!\n",
63  get_name());
64  break;
65  }
66 
68  SG_REF(m_fg);
69 }
70 
72 {
75  SG_UNREF(m_fg);
76 }
77 
78 void CMAPInference::init()
79 {
80  SG_ADD((CSGObject**)&m_fg, "fg", "factor graph", MS_NOT_AVAILABLE);
81  SG_ADD((CSGObject**)&m_outputs, "outputs", "Structured outputs", MS_NOT_AVAILABLE);
82  SG_ADD((CSGObject**)&m_infer_impl, "infer_impl", "Inference implementation", MS_NOT_AVAILABLE);
83  SG_ADD(&m_energy, "energy", "Minimized energy", MS_NOT_AVAILABLE);
84 
85  m_outputs = NULL;
86  m_infer_impl = NULL;
87  m_fg = NULL;
88  m_energy = 0;
89 }
90 
92 {
93  SGVector<int32_t> assignment(m_fg->get_num_vars());
94  assignment.zero();
95  m_energy = m_infer_impl->inference(assignment);
96 
97  // create structured output, with default normalized hamming loss
99  SGVector<float64_t> loss_weights(m_fg->get_num_vars());
100  SGVector<float64_t>::fill_vector(loss_weights.vector, loss_weights.vlen, 1.0 / loss_weights.vlen);
101  m_outputs = new CFactorGraphObservation(assignment, loss_weights); // already ref() in constructor
102  SG_REF(m_outputs);
103 }
104 
106 {
107  SG_REF(m_outputs);
108  return m_outputs;
109 }
110 
112 {
113  return m_energy;
114 }
115 
116 //-----------------------------------------------------------------
117 
119 {
120  register_parameters();
121 }
122 
124  : CSGObject()
125 {
126  register_parameters();
127  m_fg = fg;
128 }
129 
131 {
132 }
133 
134 void CMAPInferImpl::register_parameters()
135 {
136  SG_ADD((CSGObject**)&m_fg, "fg",
137  "Factor graph pointer", MS_NOT_AVAILABLE);
138 
139  m_fg = NULL;
140 }
141 
virtual float64_t inference(SGVector< int32_t > assignment)=0
#define SG_UNREF(x)
Definition: SGRefObject.h:35
#define SG_ERROR(...)
Definition: SGIO.h:131
#define REQUIRE(x,...)
Definition: SGIO.h:208
float64_t get_energy() const
CFactorGraph * m_fg
Definition: MAPInference.h:128
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:102
CFactorGraph * m_fg
Definition: MAPInference.h:83
int32_t get_num_vars() const
double float64_t
Definition: common.h:48
Class CFactorGraphObservation is used as the structured output.
#define SG_REF(x)
Definition: SGRefObject.h:34
static void fill_vector(T *vec, int32_t len, T value)
Definition: SGVector.cpp:271
CFactorGraphObservation * m_outputs
Definition: MAPInference.h:86
virtual void inference()
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:16
Class CFactorGraph a factor graph is a structured input in general.
Definition: FactorGraph.h:25
virtual const char * get_name() const
Definition: MAPInference.h:63
CFactorGraphObservation * get_structured_outputs() const
#define SG_ADD(...)
Definition: SGObject.h:71
#define SG_UNSTABLE(func,...)
Definition: SGIO.h:134
CMAPInferImpl * m_infer_impl
Definition: MAPInference.h:92

SHOGUN Machine Learning Toolbox - Documentation