我正在尝试编写不使用任何库的基于k-d树的K-NN分类代码。到目前为止,我已经能够编写k-d树的代码,但是一旦训练集形成了树,我似乎无法理解如何找到k个最近的邻居。k-d树代码如下:
#include<bits/stdc++.h> using namespace std; const int k = 2; // 2-dimensions struct Node { int point[k]; Node *left, *right; }; struct Node* newNode(int arr[]) { struct Node* temp = new Node; for (int i=0; i<k; i++) temp->point[i] = arr[i]; temp->left = temp->right = NULL; return temp; } // Inserts a new node and returns root of modified treeNode *insertRec(Node *root, int point[], unsigned depth) { if (root == NULL) return newNode(point); unsigned cd = depth % k; if (point[cd] < (root->point[cd])) root->left = insertRec(root->left, point, depth + 1); else root->right = insertRec(root->right, point, depth + 1); return root; } // Function to insert a new point with given point and return new rootNode* insert(Node *root, int point[]) { return insertRec(root, point, 0); } // driver int main() { struct Node *root = NULL; int points[][k] = {{3, 6}, {17, 15}, {13, 15}, {6, 12}, {9, 1}, {2, 7}, {10, 19}}; int n = sizeof(points)/sizeof(points[0]); for (int i=0; i<n; i++) root = insert(root, points[i]); return 0; }
回答:
首先,不要使用<bits/stdc++.h>
,那是错误的做法。
要找到k个最接近的元素,你需要以一种优先遍历最接近元素的方式遍历树。然后,如果你没有足够的元素,再去遍历那些更远的元素。
我不会在这里写出具体的代码,只提供伪代码(因为我很久以前已经构建了一个这样的实现):
list l; # 按距离排序的元素列表heap p; # 按距离排序的待遍历节点堆p.push(root)while (!p.empty()){ node = p.pop(); # 获取一个新节点 d = distance(point, node); # 计算从点到节点的最短距离 if(l.empty() or distance(point, l.back()) > d) { add(node->left); # 遍历子节点 add(node->right); l.push(points); # 添加当前节点的点 } l.pop_elements(k); # 弹出元素以保留k个}