如何使用k-d树实现K-NN分类?

我正在尝试编写不使用任何库的基于k-d树的K-NN分类代码。到目前为止,我已经能够编写k-d树的代码,但是一旦训练集形成了树,我似乎无法理解如何找到k个最近的邻居。k-d树代码如下:

#include<bits/stdc++.h> using namespace std; const int k = 2; // 2-dimensions struct Node {     int point[k];     Node *left, *right; }; struct Node* newNode(int arr[]) {     struct Node* temp = new Node;     for (int i=0; i<k; i++)     temp->point[i] = arr[i];     temp->left = temp->right = NULL;     return temp; } // Inserts a new node and returns root of modified treeNode *insertRec(Node *root, int point[], unsigned depth) {     if (root == NULL)     return newNode(point);      unsigned cd = depth % k;     if (point[cd] < (root->point[cd]))         root->left = insertRec(root->left, point, depth + 1);     else        root->right = insertRec(root->right, point, depth + 1);     return root; } // Function to insert a new point with given point and return new rootNode* insert(Node *root, int point[]) {     return insertRec(root, point, 0); } // driver int main() {     struct Node *root = NULL;     int points[][k] = {{3, 6}, {17, 15}, {13, 15}, {6, 12},                     {9, 1}, {2, 7}, {10, 19}};     int n = sizeof(points)/sizeof(points[0]);     for (int i=0; i<n; i++)     root = insert(root, points[i]);     return 0; } 

回答:

首先,不要使用<bits/stdc++.h>,那是错误的做法。

要找到k个最接近的元素,你需要以一种优先遍历最接近元素的方式遍历树。然后,如果你没有足够的元素,再去遍历那些更远的元素。

我不会在这里写出具体的代码,只提供伪代码(因为我很久以前已经构建了一个这样的实现):

list l; # 按距离排序的元素列表heap p; # 按距离排序的待遍历节点堆p.push(root)while (!p.empty()){    node = p.pop(); # 获取一个新节点    d = distance(point, node); # 计算从点到节点的最短距离    if(l.empty() or distance(point, l.back()) > d)    {        add(node->left); # 遍历子节点        add(node->right);        l.push(points); # 添加当前节点的点    }    l.pop_elements(k); # 弹出元素以保留k个}

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注