这个代码能识别MNIST数据集吗?(K-NN方法)

我不确定下面的代码是否能执行,因为它在“计算预测”这一步已经卡了很长时间。如果它不能工作,我应该做些什么改变呢?

import structimport matplotlib.pyplot as pltimport numpy as npimport osfrom scipy.special import expitfrom sklearn.neighbors import KNeighborsClassifierfrom sklearn.metrics import accuracy_scoreclf = KNeighborsClassifier()def load_data():    with open('train-labels-idx1-ubyte', 'rb') as labels:        magic, n = struct.unpack('>II', labels.read(8))        train_labels = np.fromfile(labels, dtype=np.uint8)    with open('train-images-idx3-ubyte', 'rb') as imgs:        magic, num, nrows, ncols = struct.unpack('>IIII', imgs.read(16))        train_images = np.fromfile(imgs, dtype=np.uint8).reshape(num, 784)    with open('t10k-labels-idx1-ubyte', 'rb') as labels:        magic, n = struct.unpack('>II', labels.read(8))        test_labels = np.fromfile(labels, dtype=np.uint8)    with open('t10k-images-idx3-ubyte', 'rb') as imgs:        magic, num, nrows, ncols = struct.unpack('>IIII', imgs.read(16))        test_images = np.fromfile(imgs, dtype=np.uint8).reshape(num, 784)    return train_images, train_labels, test_images, test_labelsdef knn(train_x, train_y, test_x, test_y):    clf.fit(train_x, train_y)    print("Compute predictions")    predicted = clf.predict(test_x)    print("Accuracy: ", accuracy_score(test_y, predicted))train_x, train_y, test_x, test_y = load_data()knn(train_x, train_y, test_x, test_y)

回答:

它在“计算预测”这一步已经卡了很长时间

我建议你先使用一小部分数据来测试代码是否正常运行,然后再用整个数据集运行。这样你可以确保代码逻辑是正确的。

一旦你测试完代码,就可以安全地使用整个数据集进行训练了。

这样做,你可以轻松判断代码运行缓慢是因为代码本身的问题,还是因为数据量太大(可能代码没有问题,但你可能会发现,对于比如10个样本,运行时间超过了你愿意等待的时间,这样你可以相应地调整——否则你处理的是一个太过神秘的黑盒子)。

话虽如此,如果代码本身没有问题但运行时间过长,我同样建议,像Soumya所说,尝试在Colab上运行。你在那里有不错的硬件,可以使用长达12小时的会话,并且你的电脑可以在此期间自由地测试其他代码!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注