我正在使用欧几里得距离的KNN算法对简单数据进行分类。我看到一个我想做的例子,使用MATLAB的knnsearch
函数,如下所示:
load fisheriris x = meas(:,3:4);gscatter(x(:,1),x(:,2),species)newpoint = [5 1.45];[n,d] = knnsearch(x,newpoint,'k',10);line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10)
上面的代码接受一个新点,即[5 1.45]
,并找到与新点最接近的10个值。请问有人能展示一个MATLAB算法,并详细解释knnsearch
函数的作用吗?还有其他方法可以做到这一点吗?
回答:
K最近邻(KNN)算法的基础是你有一个数据矩阵,它由N
行和M
列组成,其中N
是我们拥有的数据点的数量,而M
是每个数据点的维度。例如,如果我们在数据矩阵中放置笛卡尔坐标,这通常是一个N x 2
或N x 3
的矩阵。有了这个数据矩阵,你提供一个查询点,并在数据矩阵中搜索最接近这个查询点的k
个点。
我们通常使用查询点与数据矩阵中其他点的欧几里得距离来计算我们的距离。然而,其他距离如L1或城市街区/曼哈顿距离也被使用。在此操作之后,你将拥有N
个欧几里得或曼哈顿距离,这些距离象征着查询点与数据集中每个对应点的距离。一旦找到这些距离,你只需通过将距离按升序排序,并检索那些与数据集和查询点之间距离最小的k
个点来查找最接近查询点的k
个点。
假设你的数据矩阵存储在x
中,newpoint
是一个样本点,它有M
列(即1 x M
),这是你将遵循的一般程序,按点列出:
- 查找
newpoint
与x
中每个点的欧几里得或曼哈顿距离。 - 将这些距离按升序排序。
- 返回
x
中最接近newpoint
的k
个数据点。
让我们慢慢来做每一步。
步骤#1
有人可能这样做的一种方法或许是使用for
循环,如下所示:
N = size(x,1);dists = zeros(N,1);for idx = 1 : N dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));end
如果你想实现曼哈顿距离,这将非常简单:
N = size(x,1);dists = zeros(N,1);for idx = 1 : N dists(idx) = sum(abs(x(idx,:) - newpoint));end
dists
将是一个包含x
中每个数据点与newpoint
之间距离的N
元素向量。我们对newpoint
和x
中的一个数据点进行逐元素减法,平方差异,然后使用sum
将它们全部相加。这个和然后被开方,完成欧几里得距离。对于曼哈顿距离,你将执行逐元素减法,取绝对值,然后将所有分量相加。这可能是最容易理解的实现,但对于较大规模的数据集和数据的较高维度,这可能是最低效的…
另一种可能的解决方案是复制newpoint
并使这个矩阵与x
大小相同,然后对这个矩阵进行逐元素减法,然后对每一行的所有列求和并进行开方。因此,我们可以这样做:
N = size(x, 1);dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));
对于曼哈顿距离,你将这样做:
N = size(x, 1);dists = sum(abs(x - repmat(newpoint, N, 1)), 2);
repmat
接受一个矩阵或向量,并在给定方向上重复它们一定次数。在我们的例子中,我们希望取我们的newpoint
向量,并将其堆叠N
次以创建一个N x M
矩阵,其中每一行都是M
个元素长。我们将这两个矩阵相减,然后平方每个分量。一旦我们这样做,我们对每一行的所有列求和,最后对所有结果取平方根。对于曼哈顿距离,我们进行减法,取绝对值,然后求和。
然而,在我看来,最有效的方法是使用bsxfun
。这本质上是在一个函数调用下执行我们讨论的复制。因此,代码将非常简单,如下所示:
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
对我来说,这看起来更加简洁和直接。对于曼哈顿距离,你将这样做:
dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
步骤#2
现在我们有了我们的距离,我们只需对它们进行排序。我们可以使用sort
来排序我们的距离:
[d,ind] = sort(dists);
d
将包含按升序排序的距离,而ind
告诉你未排序数组中的每个值在排序结果中的位置。我们需要使用ind
,提取这个向量的头k
个元素,然后使用ind
来索引我们的x
数据矩阵,以返回那些最接近newpoint
的点。
步骤#3
最后一步是现在返回那些最接近newpoint
的k
个数据点。我们可以非常简单地这样做:
ind_closest = ind(1:k);x_closest = x(ind_closest,:);
ind_closest
应该包含原始数据矩阵x
中最接近newpoint
的索引。具体来说,ind_closest
包含你需要从x
中采样的哪些行以获得最接近newpoint
的点。x_closest
将包含这些实际的数据点。
为了方便你复制和粘贴,以下是代码的外观:
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));%// 或者这样做以获得曼哈顿距离% dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);[d,ind] = sort(dists);ind_closest = ind(1:k);x_closest = x(ind_closest,:);
让我们通过你的例子来看看我们的代码在实际操作中的表现:
load fisheriris x = meas(:,3:4);newpoint = [5 1.45];k = 10;%// 使用欧几里得距离dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));[d,ind] = sort(dists);ind_closest = ind(1:k);x_closest = x(ind_closest,:);
通过检查ind_closest
和x_closest
,我们得到的是:
>> ind_closestind_closest = 120 53 73 134 84 77 78 51 64 87>> x_closestx_closest = 5.0000 1.5000 4.9000 1.5000 4.9000 1.5000 5.1000 1.5000 5.1000 1.6000 4.8000 1.4000 5.0000 1.7000 4.7000 1.4000 4.7000 1.4000 4.7000 1.5000
如果你运行knnsearch
,你会看到你的变量n
与ind_closest
相匹配。然而,变量d
返回的是从newpoint
到每个点x
的距离,而不是实际的数据点本身。如果你想要实际的距离,只需在我写的代码之后执行以下操作:
dist_sorted = d(1:k);
请注意,上面的答案仅在一个包含N
个例子的批次中使用了一个查询点。KNN经常被同时用于多个例子。假设我们有Q
个查询点,我们希望在KNN中测试。这将导致一个k x M x Q
矩阵,其中对于每个例子或每个切片,我们返回k
个最接近的点,其维度为M
。或者,我们可以返回k
个最接近点的ID,从而得到一个Q x k
矩阵。让我们计算两者。
一种天真的方法是将上面的代码应用在一个循环中,并遍历每个例子。
像这样做会有效,我们分配一个Q x k
矩阵,并应用基于bsxfun
的方法来设置输出矩阵的每一行为数据集中最接近的k
个点,我们将使用Fisher Iris数据集,就像我们之前所做的那样。我们还将保持与之前示例相同的维度,我将使用四个例子,因此Q = 4
和M = 2
:
%// 加载数据并创建查询点load fisheriris;x = meas(:,3:4);newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];%// 定义k和输出矩阵Q = size(newpoints, 1);M = size(x, 2);k = 10;x_closest = zeros(k, M, Q);ind_closest = zeros(Q, k);%// 遍历每个点并执行如上所见的逻辑:for ii = 1 : Q %// 获取点 newpoint = newpoints(ii, :); %// 使用欧几里得距离 dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); [d,ind] = sort(dists); %// 新 - 输出匹配的ID以及点本身 ind_closest(ii, :) = ind(1 : k).'; x_closest(:, :, ii) = x(ind_closest(ii, :), :);end