TreeNearestNeighbors.h
Go to the documentation of this file.
1 //===========================================================================
2 /*!
3  *
4  *
5  * \brief Efficient Nearest neighbor queries.
6  *
7  *
8  *
9  * \author T. Glasmachers
10  * \date 2011
11  *
12  *
13  * \par Copyright 1995-2015 Shark Development Team
14  *
15  * <BR><HR>
16  * This file is part of Shark.
17  * <http://image.diku.dk/shark/>
18  *
19  * Shark is free software: you can redistribute it and/or modify
20  * it under the terms of the GNU Lesser General Public License as published
21  * by the Free Software Foundation, either version 3 of the License, or
22  * (at your option) any later version.
23  *
24  * Shark is distributed in the hope that it will be useful,
25  * but WITHOUT ANY WARRANTY; without even the implied warranty of
26  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27  * GNU Lesser General Public License for more details.
28  *
29  * You should have received a copy of the GNU Lesser General Public License
30  * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31  *
32  */
33 //===========================================================================
34 
35 #ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_TREENEARESTNEIGHBORS_H
36 #define SHARK_ALGORITHMS_NEARESTNEIGHBORS_TREENEARESTNEIGHBORS_H
37 
38 
39 #include <boost/intrusive/rbtree.hpp>
42 #include <shark/Data/DataView.h>
43 namespace shark {
44 
45 
46 ///
47 /// \brief Iterative nearest neighbors query.
48 ///
49 /// \par
50 /// The IterativeNNQuery class (Iterative Nearest Neighbor
51 /// Query) allows the nearest neighbors of a reference point
52 /// to be queried iteratively. Given the reference point, a
53 /// query is set up that returns the nearest neighbor first,
54 /// then the second nearest neighbor, and so on.
55 /// Thus, nearest neighbor queries are treated in an "online"
56 /// fashion. The algorithm follows the paper (generalized to
57 /// arbitrary space-partitioning trees):
58 ///
59 /// \par
60 /// Strategies for efficient incremental nearest neighbor search.
61 /// A. J. Broder. Pattern Recognition 23(1/2), pp 171-178, 1990.
62 ///
63 /// \par
64 /// The algorithm is based on traversing a BinaryTree that
65 /// partitions the space into nested cells. The triangle
66 /// inequality is applied to exclude cells from the search.
67 /// Furthermore, candidate points are cached in a queue,
68 /// such that subsequent queries profit from points that
69 /// could not be excluded this way, but that did not turn
70 /// out to be the (current) nearest neighbor.
71 ///
72 /// \par
73 /// The tree must have a bucket size of one, but leaf nodes
74 /// with multiple copies of the same point are allowed.
75 /// This means that the space partitioning must be carried
76 /// out to the finest possible scale.
77 ///
78 /// The Data must be stored in a random access container. This means that elements
79 /// have O(1) access time. This is crucial for the performance of the tree lookup.
80 /// When data is stored in a Data<T>, a View should be chosen as template parameter.
81 template <class DataContainer>
83 {
84 public:
85  typedef typename DataContainer::value_type value_type;
88  typedef std::pair<double, std::size_t> result_type;
89 
90  /// create a new query
91  /// \param tree Underlying space-partitioning tree (this is assumed to persist for the lifetime of the query object).
92  /// \param data Container holding the stored data which is referenced by the tree
93  /// \param point Point whose nearest neighbors are to be found.
94  IterativeNNQuery(tree_type const* tree, DataContainer const& data, value_type const& point)
95  : m_data(data)
96  , m_reference(point)
97  , m_nextIndex(0)
98  , mp_trace(NULL)
99  , mep_head(NULL)
100  , m_squaredRadius(0.0)
101  , m_neighbors(0)
102  {
103  // Initialize the recursion trace: descend to the
104  // leaf covering the reference point and queue it.
105  // The parent of this leaf becomes the "head".
106  mp_trace = new TraceNode(tree, NULL, m_reference);
107  TraceNode* tn = mp_trace;
108  while (tree->hasChildren())
109  {
110  tn->createLeftNode(tree, m_data, m_reference);
111  tn->createRightNode(tree, m_data, m_reference);
112  bool left = tree->isLeft(m_reference);
113  tn = (left ? tn->mep_left : tn->mep_right);
114  tree = (left ? tree->left() : tree->right());
115  }
116  mep_head = tn->mep_parent;
117  insertIntoQueue((TraceLeaf*)tn);
118  m_squaredRadius = mp_trace->squaredRadius(m_reference);
119  }
120 
121  /// destroy the query object and its internal data structures
123  m_queue.clear();
124  delete mp_trace;
125  }
126 
127 
128  /// return the number of neighbors already found
129  std::size_t neighbors() const {
130  return m_neighbors;
131  }
132 
133  /// find and return the next nearest neighbor
134  result_type next() {
135  if (m_neighbors >= mp_trace->m_tree->size())
136  throw SHARKEXCEPTION("[IterativeNNQuery::next] no more neighbors available");
137 
138  assert(! m_queue.empty());
139 
140  // Check whether the current node has points
141  // left, or whether it should be discarded.
142  if (m_neighbors > 0){
143  TraceLeaf& q = *m_queue.begin();
144  if (m_nextIndex < q.m_tree->size()){
145  return getNextPoint(q);
146  }
147  else
148  m_queue.erase(q);
149  }
150  if (m_queue.empty() || (*m_queue.begin()).m_squaredPtDistance > m_squaredRadius){
151  // enqueue more points
152  TraceNode* tn = mep_head;
153  while (tn != NULL){
154  enqueue(tn);
155  if (tn->m_status == COMPLETE) mep_head = tn->mep_parent;
156  tn = tn->mep_parent;
157  }
158 
159  // re-compute the radius
160  m_squaredRadius = mp_trace->squaredRadius(m_reference);
161  }
162  m_nextIndex = 0;
163  ++m_neighbors;
164  return getNextPoint(*m_queue.begin());
165  }
166 
167  /// return the size of the queue,
168  /// which is a measure of the
169  /// overhead of the search
170  std::size_t queuesize() const{
171  return m_queue.size();
172  }
173 
174 private:
175 
/// \brief Progress of a TraceNode during the search.
enum Status
{
    NONE = 0,     ///< no points of this node have been queued yet
    PARTIAL = 1,  ///< some of the points of this node have been queued
    COMPLETE = 2  ///< all points of this node have been queued
};
183 
184  /// The TraceNode class builds up a tree during
185  /// the search. This tree covers only those parts
186  /// of the space partirioning tree that need to be
187  /// traversed in order to find the next nearest
188  /// neighbor.
189  class TraceNode
190  {
191  public:
192  /// Constructor
193  TraceNode(tree_type const* tree, TraceNode* parent, value_type const& reference)
194  : m_tree(tree)
195  , m_status(NONE)
196  , mep_parent(parent)
197  , mep_left(NULL)
198  , mep_right(NULL)
199  , m_squaredDistance(tree->squaredDistanceLowerBound(reference))
200  { }
201 
202  /// Destructor
203  virtual ~TraceNode()
204  {
205  if (mep_left != NULL) delete mep_left;
206  if (mep_right != NULL) delete mep_right;
207  }
208 
209  void createLeftNode(tree_type const* tree, DataContainer const& data, value_type const& reference){
210  if (tree->left()->hasChildren())
211  mep_left = new TraceNode(tree->left(), this, reference);
212  else
213  mep_left = new TraceLeaf(tree->left(), this, data, reference);
214  }
215  void createRightNode(tree_type const* tree, DataContainer const& data, value_type const& reference){
216  if (tree->right()->hasChildren())
217  mep_right = new TraceNode(tree->right(), this, reference);
218  else
219  mep_right = new TraceLeaf(tree->right(), this, data, reference);
220  }
221 
222  /// Compute the squared distance of the area not
223  /// yet covered by the queue to the reference point.
224  /// This is also referred to as the squared "radius"
225  /// of the area covered by the queue (in fact, it is
226  /// the radius of the largest sphere around the
227  /// reference point that fits into the covered area).
228  double squaredRadius(value_type const& ref) const{
229  if (m_status == NONE) return m_squaredDistance;
230  else if (m_status == PARTIAL)
231  {
232  double l = mep_left->squaredRadius(ref);
233  double r = mep_right->squaredRadius(ref);
234  return std::min(l, r);
235  }
236  else return 1e100;
237  }
238 
239  /// node of the tree
240  tree_type const* m_tree;
241 
242  /// status of the search
243  Status m_status;
244 
245  /// parent node
246  TraceNode* mep_parent;
247 
248  /// "left" child
249  TraceNode* mep_left;
250 
251  /// "right" child
252  TraceNode* mep_right;
253 
254  /// squared distance of the box to the reference point
255  double m_squaredDistance;
256  };
257 
258  /// hook type for intrusive container
259  typedef boost::intrusive::set_base_hook<> HookType;
260 
261  /// Leaves of the three have three roles:
262  /// (1) they are tree nodes holding exactly one point
263  /// (possibly multiple copies of this point),
264  /// (2) they know the distance of their point to the
265  /// reference point,
266  /// (3) they can be added to the candidates queue.
267  class TraceLeaf : public TraceNode, public HookType
268  {
269  public:
270  /// Constructor
271  TraceLeaf(tree_type const* tree, TraceNode* parent, DataContainer const& data, value_type const& ref)
272  : TraceNode(tree, parent, ref){
273  //check whether the tree uses a differen metric than a linear one.
274  if(tree->kernel() != NULL)
275  m_squaredPtDistance = tree->kernel()->featureDistanceSqr(data[tree->index(0)], ref);
276  else
277  m_squaredPtDistance = distanceSqr(data[tree->index(0)], ref);
278  }
279 
280  /// Destructor
281  ~TraceLeaf() { }
282 
283 
284  /// Comparison by distance, ties are broken arbitrarily,
285  /// but deterministically, by tree node pointer.
286  inline bool operator < (TraceLeaf const& rhs) const{
287  if (m_squaredPtDistance == rhs.m_squaredPtDistance)
288  return (this->m_tree < rhs.m_tree);
289  else
290  return (m_squaredPtDistance < rhs.m_squaredPtDistance);
291  }
292 
293  /// Squared distance of the single point in the leaf to the reference point.
294  double m_squaredPtDistance;
295  };
296 
297  /// insert a point into the queue
298  void insertIntoQueue(TraceLeaf* leaf){
299  m_queue.insert_unique(*leaf);
300 
301  // traverse up the tree, updating the state
302  TraceNode* tn = leaf;
303  tn->m_status = COMPLETE;
304  while (true){
305  TraceNode* par = tn->mep_parent;
306  if (par == NULL) break;
307  if (par->m_status == NONE){
308  par->m_status = PARTIAL;
309  break;
310  }
311  else if (par->m_status == PARTIAL){
312  if (par->mep_left == tn){
313  if (par->mep_right->m_status == COMPLETE) par->m_status = COMPLETE;
314  else break;
315  }
316  else{
317  if (par->mep_left->m_status == COMPLETE) par->m_status = COMPLETE;
318  else break;
319  }
320  }
321  tn = par;
322  }
323  }
324 
325  result_type getNextPoint(TraceLeaf const& leaf){
326  double dist = std::sqrt(leaf.m_squaredPtDistance);
327  std::size_t index = leaf.m_tree->index(m_nextIndex);
328  ++m_nextIndex;
329  return std::make_pair(dist,index);
330  }
331 
332  /// Recursively descend the node and enqueue
333  /// all points in cells intersecting the
334  /// current bounding sphere.
335  void enqueue(TraceNode* tn){
336  // check whether this node needs to be enqueued
337  if (tn->m_status == COMPLETE) return;
338  if (! m_queue.empty() && tn->m_squaredDistance >= (*m_queue.begin()).m_squaredPtDistance) return;
339 
340  const tree_type* tree = tn->m_tree;
341  if (tree->hasChildren()){
342  // extend the tree at need
343  if (tn->mep_left == NULL){
344  tn->createLeftNode(tree,m_data,m_reference);
345  }
346  if (tn->mep_right == NULL){
347  tn->createRightNode(tree,m_data,m_reference);
348  }
349 
350  // first descend into the closer sub-tree
351  if (tree->isLeft(m_reference))
352  {
353  // left first
354  enqueue(tn->mep_left);
355  enqueue(tn->mep_right);
356  }
357  else
358  {
359  // right first
360  enqueue(tn->mep_right);
361  enqueue(tn->mep_left);
362  }
363  }
364  else
365  {
366  TraceLeaf* leaf = (TraceLeaf*)tn;
367  insertIntoQueue(leaf);
368  }
369  }
370 
371  /// the queue is a self-balancing tree of sorted entries
372  typedef boost::intrusive::rbtree<TraceLeaf> QueueType;
373 
374 
375  ///\brief Datastorage to lookup the points referenced by the space partitioning tree.
376  DataContainer const& m_data;
377 
378  /// reference point for this query
379  value_type m_reference;
380 
381  /// queue of candidates
382  QueueType m_queue;
383 
384  /// index of the next not yet returned element
385  /// of the current leaf.
386  std::size_t m_nextIndex;
387 
388  /// recursion trace tree
389  TraceNode* mp_trace;
390 
391  /// "head" of the trace tree. This is the
392  /// node containing the reference point,
393  /// but so high up in the tree that it is
394  /// not fully queued yet.
395  TraceNode* mep_head;
396 
397  /// squared radius of the "covered" area
398  double m_squaredRadius;
399 
400  /// number of neighbors already returned
401  std::size_t m_neighbors;
402 };
403 
404 
405 ///\brief Nearest Neighbors implementation using binary trees
406 ///
407 /// Returns the labels and distances of the k nearest neighbors of a point.
408 template<class InputType, class LabelType>
409 class TreeNearestNeighbors:public AbstractNearestNeighbors<InputType,LabelType>
410 {
411 private:
413 
414 public:
419 
420  TreeNearestNeighbors(Dataset const& dataset, Tree const* tree)
421  : m_dataset(dataset), m_inputs(dataset.inputs()), m_labels(dataset.labels()),mep_tree(tree)
422  { }
423 
424  ///\brief returns the k nearest neighbors of the point
425  std::vector<DistancePair> getNeighbors(BatchInputType const& patterns, std::size_t k)const{
426  std::size_t numPoints = shark::size(patterns);
427  std::vector<DistancePair> results(k*numPoints);
428  for(std::size_t p = 0; p != numPoints; ++p){
429  IterativeNNQuery<DataView<Data<InputType> const> > query(mep_tree, m_inputs, get(patterns, p));
430  //find the neighbors using the queries
431  for(std::size_t i = 0; i != k; ++i){
432  typename IterativeNNQuery<DataView<Data<InputType> const> >::result_type result = query.next();
433  results[i+p*k].key=result.first;
434  results[i+p*k].value= m_labels[result.second];
435  }
436  }
437  return results;
438  }
439 
441  return m_dataset;
442  }
443 
444 private:
445  Dataset const& m_dataset;
446  DataView<Data<InputType> const> m_inputs;
447  DataView<Data<LabelType> const> m_labels;
448  Tree const* mep_tree;
449 
450 };
451 
452 
453 }
454 #endif