scikit learn中的Meanshift算法(Python)无法理解数据类型

我有一个数据集,包含7265个样本132个特征。我想使用scikit learn中的meanshift算法,但遇到了以下错误:

Traceback (most recent call last):  File "C:\Users\OJ\Dropbox\Dt\Code\visual\facetest\facetracker_video.py", line 130, in <module>    labels, centers = getClusters(data,clusters)  File "C:\Users\OJ\Dropbox\Dt\Code\visual\facetest\facetracker_video.py", line 34, in getClusters    ms.fit(np.array(dataarray))  File "C:\python2.7\lib\site-packages\sklearn\cluster\mean_shift_.py", line 280, in fit    cluster_all=self.cluster_all)  File "C:\python2.7\lib\site-packages\sklearn\cluster\mean_shift_.py", line 137, in mean_shift    nbrs = NearestNeighbors(radius=bandwidth).fit(sorted_centers)  File "C:\python2.7\lib\site-packages\sklearn\neighbors\base.py", line 642, in fit    return self._fit(X)  File "C:\python2.7\lib\site-packages\sklearn\neighbors\base.py", line 180, in _fit    raise ValueError("data type not understood")ValueError: data type not understood

我的代码如下:

dataarray = np.array(data)bandwidth = estimate_bandwidth(dataarray, quantile=0.2, n_samples=len(dataarray))ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)ms.fit(dataarray)labels = ms.labels_cluster_centers = ms.cluster_centers_

如果我检查数据变量的类型,会看到:

print isinstance( dataarray, np.ndarray )>>> True

带宽是0.925538333061,dataarray.dtypefloat64

我使用的是scikit learn 0.14.1

我可以使用scikit中的其他算法进行聚类(尝试了kmeans和dbscan)。我做错了什么?


编辑:

数据可以在这里找到:(pickle格式):http://ojtwist.be/datatocluster.p和:http://ojtwist.be/datatocluster.npz


回答:

这是scikit项目中的一个bug。这里有文档记录

在拟合过程中存在从浮点数到整数的转换,这在某些情况下可能会导致崩溃(通过将种子点放置在bins的角落而不是中心)。链接中有一些代码可以修复这个问题。

如果你不想直接修改scikit的代码(并且保持你的代码与其他机器的兼容性),我建议你在将数据传递给MeanShift之前对其进行归一化处理。

试试这个:

>>>from sklearn import preprocessing>>>data2 = preprocessing.scale(dataarray)

然后在你的代码中使用data2。这对我来说是有效的。

如果你不想采用任何一种解决方案,这是一个为项目做出贡献的好机会,可以提交一个包含解决方案的拉取请求 🙂

编辑:你可能希望保留信息以“反归一化”meanshift的结果。因此,使用一个StandardScaler对象,而不是使用函数进行归一化处理。

祝你好运!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注