Finding the 5 nearest neighbors using a KD-tree

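For each point in the blue set (T-SNE1), I want to find its 5 nearest neighbors among the red points (T-SNE2). I wrote this code just to work out the right approach, but I'm not sure whether it is correct.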

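    import numpy as np
    # assumed import: the leaf_size keyword matches sklearn.neighbors.KDTree
    from sklearn.neighbors import KDTree

    X = np.random.random((10, 2))  # 10 points in 2 dimensions
    Y = np.random.random((10, 2))  # 10 points in 2 dimensions
    NNlist = []
    treex = KDTree(X, leaf_size=2)  # build the tree once, on X
    for i in range(len(Y)):
        # find the 5 points of X closest to Y[i]
        dist, ind = treex.query([Y[i]], k=5)
        NNlist.append(ind[0][0])
        print(ind)  # indices of 5 closest neighbors
        print(dist)
        print("the nearest index is:", ind[0][0], "with distance:", dist[0][0], "for Y", i)
    print(NNlist)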

Output:

    [[9 5 4 6 0]]
    [[ 0.21261486  0.32859024  0.41598597  0.42960146  0.43793039]]
    the nearest index is: 9 with distance: 0.212614862956 for Y 0
    [[0 3 2 6 1]]
    [[ 0.10907128  0.11378059  0.13984741  0.18000197  0.27475481]]
    the nearest index is: 0 with distance: 0.109071275144 for Y 1
    [[8 2 3 0 1]]
    [[ 0.21621245  0.30543878  0.40668179  0.4370689   0.49372232]]
    the nearest index is: 8 with distance: 0.216212445449 for Y 2
    [[8 3 2 6 0]]
    [[ 0.16648482  0.2989508   0.40967709  0.42511931  0.46589575]]
    the nearest index is: 8 with distance: 0.166484820786 for Y 3
    [[1 2 5 0 4]]
    [[ 0.15331281  0.25121761  0.29305736  0.30173474  0.44291615]]
    the nearest index is: 1 with distance: 0.153312811422 for Y 4
    [[2 3 8 0 6]]
    [[ 0.20441037  0.20917797  0.25121628  0.2903253   0.33914051]]
    the nearest index is: 2 with distance: 0.204410367254 for Y 5
    [[2 1 0 3 5]]
    [[ 0.08400022  0.1484925   0.17356156  0.32387147  0.33789602]]
    the nearest index is: 2 with distance: 0.0840002184199 for Y 6
    [[8 2 3 7 0]]
    [[ 0.2149891   0.40584999  0.50054235  0.53307269  0.5389266 ]]
    the nearest index is: 8 with distance: 0.21498909502 for Y 7
    [[1 0 2 5 9]]
    [[ 0.07265268  0.11687068  0.19065327  0.20004392  0.30269591]]
    the nearest index is: 1 with distance: 0.0726526838766 for Y 8
    [[5 9 4 1 0]]
    [[ 0.21563204  0.25067242  0.29904262  0.36745386  0.39634179]]
    the nearest index is: 5 with distance: 0.21563203953 for Y 9
    [9, 0, 8, 8, 1, 2, 2, 8, 1, 5]

Answer:

    import numpy as np
    from scipy.spatial import KDTree

    X = np.random.random((10, 2))  # 10 points in 2 dimensions
    Y = np.random.random((10, 2))  # 10 points in 2 dimensions

    for i in range(len(X)):
        # build a tree over Y plus the query point X[i], which lands at index len(Y)
        treey = KDTree(np.concatenate([Y, np.expand_dims(X[i], axis=0)], axis=0))
        # k=6: the query point itself (distance 0) plus its 5 nearest neighbors in Y
        dist, ind = treey.query([X[i]], k=6)
        print('index', ind)
        print('distance', dist)
        print('5 nearest neighbors')
        for j in ind[0][1:]:  # skip the first hit, which is X[i] itself
            print(Y[j])
        print()

You get…

    index [[10  5  8  9  1  2]]
    distance [[ 0.          0.3393312   0.38565112  0.40120109  0.44200758  0.47675255]]
    5 nearest neighbors
    [ 0.6298789   0.18283264]
    [ 0.42952574  0.83918788]
    [ 0.26258905  0.4115705 ]
    [ 0.61789523  0.96261285]
    [ 0.92417172  0.13276541]

    index [[10  1  3  8  4  9]]
    distance [[ 0.          0.09176157  0.18219064  0.21845335  0.28876942  0.60082231]]
    5 nearest neighbors
    [ 0.61789523  0.96261285]
    [ 0.51031835  0.99761715]
    [ 0.42952574  0.83918788]
    [ 0.3744326   0.97577322]
    [ 0.26258905  0.4115705 ]

    index [[10  7  0  9  5  6]]
    distance [[ 0.          0.15771386  0.2751765   0.3457175   0.49918935  0.70597498]]
    5 nearest neighbors
    [ 0.19803817  0.23495888]
    [ 0.41293849  0.05585981]
    [ 0.26258905  0.4115705 ]
    [ 0.6298789   0.18283264]
    [ 0.04527532  0.78806495]

    index [[10  0  5  7  9  2]]
    distance [[ 0.          0.09269963  0.20597988  0.24505542  0.31104979  0.49743673]]
    5 nearest neighbors
    [ 0.41293849  0.05585981]
    [ 0.6298789   0.18283264]
    [ 0.19803817  0.23495888]
    [ 0.26258905  0.4115705 ]
    [ 0.92417172  0.13276541]

    index [[10  9  5  7  0  8]]
    distance [[ 0.          0.20406876  0.26125464  0.30645317  0.33369641  0.45509834]]
    5 nearest neighbors
    [ 0.26258905  0.4115705 ]
    [ 0.6298789   0.18283264]
    [ 0.19803817  0.23495888]
    [ 0.41293849  0.05585981]
    [ 0.42952574  0.83918788]
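As a possible alternative, here is a minimal sketch that builds the tree once over the set being searched and queries every point in a single call, then checks the result against a brute-force distance computation. The names `blue` and `red` are hypothetical stand-ins for the T-SNE1 and T-SNE2 point sets, and the fixed seed is only there to make the run repeatable. It assumes `scipy.spatial.KDTree`, as in the answer above.

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(0)  # fixed seed, for a repeatable illustration only
blue = rng.random((10, 2))      # query points (hypothetical stand-in for T-SNE1)
red = rng.random((10, 2))       # points to search (hypothetical stand-in for T-SNE2)

# Build the tree once over the points we search in, then query all blue points at once.
tree = KDTree(red)
dist, ind = tree.query(blue, k=5)  # both arrays have shape (len(blue), 5)

# Brute-force check: all pairwise Euclidean distances, then the 5 smallest per row.
pairwise = np.linalg.norm(blue[:, None, :] - red[None, :, :], axis=-1)
brute_ind = np.argsort(pairwise, axis=1)[:, :5]
brute_dist = np.take_along_axis(pairwise, brute_ind, axis=1)
assert np.allclose(dist, brute_dist)

for i in range(len(blue)):
    # ind[i] holds row indices into red, sorted from nearest to farthest
    print("blue", i, "-> red indices", ind[i])
```

Row `i` of `ind` indexes into `red`, so `red[ind[i]]` gives the coordinates of the 5 neighbors of `blue[i]`. Because the tree contains only `red`, there is no need to append the query point and discard the zero-distance hit.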
