我在Python中实现了kmeans算法,代码如下。我使用了一些简单的数据来测试代码,这些数据存储在一个名为data.txt的文件中,如下所示:
2 5
3 7
-1 -2
-3 -3
5 4
4 -4
3 -7
3.5 -9
我的问题是在迭代过程中,一些聚类似乎变成了空的,即(聚类数量)< k,经过我的分析,这似乎是会发生的,但在网上搜索后,我发现没有人处理kmeans算法中的这个问题。
所以我不知道问题出在哪里?是因为我的测试数据太简单了吗?
import sysimport numpy as npfrom math import sqrt"""useage: python mykmeans.py mydata.txt k"""GAP = 2MIN_VAL = 1000000def get_distance(point1, point2): dis = sqrt(pow(point1[0] - point2[0], 2) + pow(point1[1] - point2[1], 2)) return disdef cluster_dis(centroid, cluster): dis = 0.0 for point in cluster: dis += get_distance(centroid, point) return disdef update_centroids(centroids, cluster_id, cluster): x, y = 0.0, 0.0 length = len(cluster) if length == 0: # TODO:这是我的问题吗?我们需要检查这个吗? return for item in cluster: x += item[0] y += item[1] centroids[cluster_id] = (x / length, y / length)def kmeans(data, k): assert k <= len(data) seed_ids = np.random.randint(0, len(data), k) centroids = [data[idx] for idx in seed_ids] clusters = [[] for _ in xrange(k)] cluster_idx = [-1] * len(data) pre_dis = 0 while True: for point_id, point in enumerate(data): min_distance, tmp_id = MIN_VAL, -1 for seed_id, seed in enumerate(centroids): distance = get_distance(seed, point) if distance < min_distance: min_distance = distance tmp_id = seed_id if cluster_idx[point_id] != -1: dex = clusters[cluster_idx[point_id]].index(point) del clusters[cluster_idx[point_id]][dex] clusters[tmp_id].append(point) cluster_idx[point_id] = tmp_id now_dis = 0.0 for cluster_id, cluster in enumerate(clusters): now_dis += cluster_dis(centroids[cluster_id], cluster) update_centroids(centroids, cluster_id, cluster) delta_dis = now_dis - pre_dis pre_dis = now_dis if delta_dis < GAP: break print(centroids) print(clusters) return centroids, clustersdef get_data(file_name): try: fr = open(file_name) lines = fr.read().splitlines() except IOError, e: pass finally: fr.close() data = [] for line in lines: tmp = line.split() x, y = float(tmp[0]), float(tmp[1]) data.append([x, y]) return datadef main(): args = sys.argv[1:] assert len(args) > 1 file_name, k = args[0], int(args[1]) data = get_data(file_name) kmeans(data, k)if __name__ == '__main__': main()
回答:
k-means算法可能会导致空聚类。这里有一个例子,以图示方式展示。我还复制了下面的图,以防链接有一天失效。
下面的第一张图显示了7个点的分布。初始时,3、5和6被选为聚类中心。
下面的’+’显示了第一次迭代后聚类中心的变化,相同颜色表示相应的点在同一个聚类中。
从下面的图中可以看到,经过两次迭代后,蓝色聚类变成了空的,实际上只有2个聚类,而不是初始设定的3个。
因此,空聚类可能是由于初始化和’不正确’的聚类数量造成的。你可以尝试在代码中使用不同的k
值,并多次运行程序来观察聚类结果,使其更加robust。