// NOTE(review): this copy of the header looks like a lossy extraction. Original
// line numbers are fused into the text, and many lines (braces, class heads,
// loop headers, else-branches) are missing, so it does not compile as-is.
// Comments below only describe what the visible fragments show.
// The template head below belongs to the iterative query class; its class head
// is not visible (presumably IterativeNNQuery, per the constructor further down).
35 #ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_TREENEARESTNEIGHBORS_H 36 #define SHARK_ALGORITHMS_NEARESTNEIGHBORS_TREENEARESTNEIGHBORS_H 39 #include <boost/intrusive/rbtree.hpp> 81 template <
class DataContainer>
/// Element type stored in DataContainer; also the type of the query point.
85 typedef typename DataContainer::value_type value_type;
/// Query result: (distance, index of the point in the data container).
88 typedef std::pair<double, std::size_t> result_type;
/// Prepares an iterative nearest-neighbor query for `point` over `tree` / `data`.
///
/// Visible steps:
///  - create the root of the trace tree;
///  - descend toward a leaf, creating both trace children at each level;
///  - insert the reached leaf into the queue;
///  - initialise m_squaredRadius from the trace root.
/// NOTE(review): the initialiser lines for m_data / m_reference and the descent
/// loop header are not visible here. m_reference is presumably a copy of `point`.
94 IterativeNNQuery(tree_type
const* tree, DataContainer
const& data, value_type
const& point)
100 , m_squaredRadius(0.0)
// Root of the trace tree, paired with the root of the search tree (no parent).
106 mp_trace =
new TraceNode(tree, NULL, m_reference);
107 TraceNode* tn = mp_trace;
// Per level: create both trace children, then follow the child selected by
// tree->isLeft(m_reference), in both the trace tree and the search tree.
110 tn->createLeftNode(tree, m_data, m_reference);
111 tn->createRightNode(tree, m_data, m_reference);
112 bool left = tree->
isLeft(m_reference);
113 tn = (left ? tn->mep_left : tn->mep_right);
114 tree = (left ? tree->
left() : tree->
right());
// next() starts expanding the trace tree from mep_head, i.e. the parent of the first leaf.
116 mep_head = tn->mep_parent;
// C-style downcast: tn is assumed to be a TraceLeaf once the descent ends.
// That is presumably guaranteed by the (not visible) loop condition.
117 insertIntoQueue((TraceLeaf*)tn);
118 m_squaredRadius = mp_trace->squaredRadius(m_reference);
/// Returns the next neighbor as a result_type (distance, index).
/// Throws (via SHARKEXCEPTION) when m_neighbors >= mp_trace->m_tree->size(),
/// i.e. when no more neighbors are available.
/// NOTE(review): the method signature and several statements between the
/// visible lines are missing from this copy.
135 if (m_neighbors >= mp_trace->m_tree->size())
136 throw SHARKEXCEPTION(
"[IterativeNNQuery::next] no more neighbors available");
138 assert(! m_queue.empty());
// After the first neighbor: keep returning points from the closest queued leaf
// while it still has a point at position m_nextIndex.
142 if (m_neighbors > 0){
143 TraceLeaf& q = *m_queue.begin();
144 if (m_nextIndex < q.m_tree->
size()){
145 return getNextPoint(q);
// If the queue is empty, or its closest leaf is farther than m_squaredRadius,
// expand the trace tree starting at mep_head. mep_head moves to its parent once
// its subtree is COMPLETE. (The loop body is only partly visible.)
150 if (m_queue.empty() || (*m_queue.begin()).m_squaredPtDistance > m_squaredRadius){
152 TraceNode* tn = mep_head;
155 if (tn->m_status == COMPLETE) mep_head = tn->mep_parent;
// Recompute the squared radius from the trace root after expansion.
160 m_squaredRadius = mp_trace->squaredRadius(m_reference);
// Report a point from the closest queued leaf (begin() of the ordered queue).
164 return getNextPoint(*m_queue.begin());
/// Number of trace leaves currently held in the queue (the signature line is not visible).
171 return m_queue.size();
/// Node of the trace tree, which shadows the visited part of the search tree.
/// Each node refers to one search-tree node (m_tree) and owns its children,
/// which are deleted in the destructor. Children are created on demand
/// (see createLeftNode / createRightNode).
/// NOTE(review): the class head, the member initialisers and some members
/// (mep_left, m_status) are not visible in this copy.
193 TraceNode(tree_type
const* tree, TraceNode* parent, value_type
const& reference)
// Destructor body: recursively frees both child subtrees.
// NOTE(review): TraceLeaf objects are created by createLeftNode/createRightNode
// and destroyed through TraceNode*, so this destructor must be virtual. Its
// declaration line is not visible here, so confirm it is. The NULL checks are
// redundant, because deleting a null pointer is a no-op.
205 if (mep_left != NULL)
delete mep_left;
206 if (mep_right != NULL)
delete mep_right;
/// Creates the left child for search-tree node `tree`: either a plain TraceNode
/// or a TraceLeaf (which needs `data` to compute its point distance).
/// NOTE(review): the condition choosing between the two is not visible;
/// presumably it tests whether tree->left() is a leaf.
209 void createLeftNode(tree_type
const* tree, DataContainer
const& data, value_type
const& reference){
211 mep_left =
new TraceNode(tree->
left(),
this, reference);
213 mep_left =
new TraceLeaf(tree->
left(),
this, data, reference);
/// Mirror of createLeftNode for the right child (tree->right()).
215 void createRightNode(tree_type
const* tree, DataContainer
const& data, value_type
const& reference){
217 mep_right =
new TraceNode(tree->
right(),
this, reference);
219 mep_right =
new TraceLeaf(tree->
right(),
this, data, reference);
/// Squared radius of this trace subtree, as used by the query:
///  - NONE node: m_squaredDistance;
///  - PARTIAL node: derived from both children's squaredRadius.
/// NOTE(review): how l and r are combined, and the COMPLETE case, are not visible here.
228 double squaredRadius(value_type
const& ref)
const{
229 if (m_status ==
NONE)
return m_squaredDistance;
230 else if (m_status == PARTIAL)
232 double l = mep_left->squaredRadius(ref);
233 double r = mep_right->squaredRadius(ref);
/// Search-tree node this trace node corresponds to.
240 tree_type
const* m_tree;
/// Parent in the trace tree; NULL for the root.
246 TraceNode* mep_parent;
/// Right child, owned by this node; NULL until created.
252 TraceNode* mep_right;
/// Squared distance that enqueue() compares against the closest queued leaf.
/// It is presumably a lower bound for the points below this node (confirm).
255 double m_squaredDistance;
/// Intrusive hook that lets TraceLeaf objects be linked directly into the
/// boost::intrusive rbtree used as the queue.
259 typedef boost::intrusive::set_base_hook<> HookType;
/// Leaf of the trace tree; also an element of the distance-ordered queue.
267 class TraceLeaf :
public TraceNode,
public HookType
/// Computes m_squaredPtDistance with tree->kernel()->featureDistanceSqr, between
/// the leaf's point data[tree->index(0)] and `ref`.
/// NOTE(review): the branch taken when tree->kernel() is NULL is not visible.
271 TraceLeaf(tree_type
const* tree, TraceNode* parent, DataContainer
const& data, value_type
const& ref)
272 : TraceNode(tree, parent, ref){
274 if(tree->
kernel() != NULL)
275 m_squaredPtDistance = tree->
kernel()->featureDistanceSqr(data[tree->
index(0)], ref);
/// Queue ordering: by m_squaredPtDistance, with ties broken by search-tree node
/// address. The tie-break keeps distinct leaves at equal distance from being
/// rejected as duplicates by insert_unique.
286 inline bool operator < (TraceLeaf
const& rhs)
const{
287 if (m_squaredPtDistance == rhs.m_squaredPtDistance)
288 return (this->m_tree < rhs.m_tree);
290 return (m_squaredPtDistance < rhs.m_squaredPtDistance);
/// Squared distance from the reference to the leaf's first point, data[m_tree->index(0)].
/// NOTE(review): getNextPoint reports this distance for every point of the leaf.
/// That is exact only if all points in a leaf coincide, so confirm the tree invariant.
294 double m_squaredPtDistance;
/// Adds `leaf` to the queue, marks it COMPLETE, then walks up the trace tree
/// updating status:
///  - a NONE parent becomes PARTIAL;
///  - a PARTIAL parent becomes COMPLETE once its other child is COMPLETE.
/// The walk stops at the root, where the parent is NULL.
/// NOTE(review): the loop header and the statement advancing `tn` are not visible.
298 void insertIntoQueue(TraceLeaf* leaf){
299 m_queue.insert_unique(*leaf);
302 TraceNode* tn = leaf;
303 tn->m_status = COMPLETE;
305 TraceNode* par = tn->mep_parent;
306 if (par == NULL)
break;
307 if (par->m_status ==
NONE){
308 par->m_status = PARTIAL;
311 else if (par->m_status == PARTIAL){
// Coming from the left child: the parent is complete once the right child is too.
312 if (par->mep_left == tn){
313 if (par->mep_right->m_status == COMPLETE) par->m_status = COMPLETE;
// Coming from the right child: check the left sibling instead.
317 if (par->mep_left->m_status == COMPLETE) par->m_status = COMPLETE;
/// Builds the result for the point at position m_nextIndex inside `leaf`:
/// (sqrt(leaf.m_squaredPtDistance), leaf.m_tree->index(m_nextIndex)).
/// NOTE(review): one line (original 328) is missing from this copy. It
/// presumably advances the m_nextIndex / m_neighbors counters.
325 result_type getNextPoint(TraceLeaf
const& leaf){
326 double dist = std::sqrt(leaf.m_squaredPtDistance);
327 std::size_t index = leaf.m_tree->index(m_nextIndex);
329 return std::make_pair(dist,index);
/// Recursively expands the trace tree below `tn` and inserts the leaves it
/// reaches into the queue.
/// NOTE(review): the inner-node/leaf branch condition and the else-lines are
/// not visible in this copy.
335 void enqueue(TraceNode* tn){
// Nothing left to do below a COMPLETE node.
337 if (tn->m_status == COMPLETE)
return;
// Prune: this subtree's squared distance is not smaller than the closest queued leaf's.
338 if (! m_queue.empty() && tn->m_squaredDistance >= (*m_queue.begin()).m_squaredPtDistance)
return;
340 const tree_type* tree = tn->m_tree;
// Inner node: create any missing children on demand.
343 if (tn->mep_left == NULL){
344 tn->createLeftNode(tree,m_data,m_reference);
346 if (tn->mep_right == NULL){
347 tn->createRightNode(tree,m_data,m_reference);
// Visit the child selected by tree->isLeft(m_reference) first, then the other one.
351 if (tree->
isLeft(m_reference))
354 enqueue(tn->mep_left);
355 enqueue(tn->mep_right);
360 enqueue(tn->mep_right);
361 enqueue(tn->mep_left);
// Leaf: C-style downcast (assumes tn is a TraceLeaf in this branch), then queue it.
366 TraceLeaf* leaf = (TraceLeaf*)tn;
367 insertIntoQueue(leaf);
/// Distance-ordered queue of trace leaves, stored as an intrusive red-black tree.
/// begin() is the smallest element according to TraceLeaf::operator<.
372 typedef boost::intrusive::rbtree<TraceLeaf> QueueType;
/// Data being searched. It is held by reference, so it must outlive the query.
376 DataContainer
const& m_data;
/// Reference (query) point used in all distance computations; stored by value.
379 value_type m_reference;
/// Position of the next point to report inside the current closest leaf.
386 std::size_t m_nextIndex;
/// Value of mp_trace->squaredRadius(m_reference). next() expands the trace tree
/// when the closest queued leaf lies beyond it.
398 double m_squaredRadius;
/// Neighbor counter. next() throws once it reaches the number of points in the
/// tree; it presumably counts the neighbors returned so far.
401 std::size_t m_neighbors;
/// Nearest-neighbor lookup over a labeled dataset, backed by a search tree.
/// NOTE(review): the class head is not visible (presumably TreeNearestNeighbors,
/// per the header guard), nor are the constructor signature and most members.
408 template<
class InputType,
class LabelType>
// Constructor initialiser: keeps the dataset by reference (it must outlive this
// object), the dataset's inputs and labels, and a pointer to the tree.
// Ownership of the tree is not visible here.
421 : m_dataset(dataset), m_inputs(dataset.inputs()), m_labels(dataset.labels()),mep_tree(tree)
/// Returns k neighbors for every pattern, laid out pattern-major:
/// the entry at [i + p*k] is the i-th neighbor of pattern p.
/// key = distance (result.first); value = label of the neighbor (m_labels[result.second]).
/// NOTE(review): the definitions of numPoints and of the per-pattern query
/// producing `result` are not visible in this copy.
425 std::vector<DistancePair>
getNeighbors(BatchInputType
const& patterns, std::size_t k)
const{
427 std::vector<DistancePair> results(k*numPoints);
428 for(std::size_t p = 0; p != numPoints; ++p){
431 for(std::size_t i = 0; i != k; ++i){
433 results[i+p*k].key=result.first;
434 results[i+p*k].value= m_labels[result.second];
/// Dataset being searched (non-owning reference).
445 Dataset
const& m_dataset;
/// Search tree used for the queries.
448 Tree
const* mep_tree;