Mnist数据集模式识别准确率

我对matlab完全是新手,这是我的第一个项目。Mnist数据集有60000张从0到9的图片用于训练,1000张图片用于测试。我尝试通过计算均值为这10个类别(0到9)创建模式,然后使用欧几里得距离进行识别。这个方法非常简单,但准确率非常低。我不知道问题出在哪里,导致准确率如此低。准确率:1.73%

这是我用于寻找10个类别模式的代码:

root = 'F:\matlab\ex1\exercise-EquivaliencOfL2DistanceAndDotProduct\dataset';fn = strcat (root, '\MnistTrainX.mat');load (fn);fn = strcat (root, '\MnistTrainY.mat');load (fn);weights = zeros (10, 784);b = zeros (10, 1);im=reshape(MnistTrainX(5,:),[28 ,28]);imshow(im,[]);imshow(im',[]);for c=1 : 10    idx=find(MnistTrainY == c-1);   weights (c,:)=mean( MnistTrainX(idx,:));endtrainAccuray = ComputeInnerProductAccuracy(weights,b, MnistTrainX,MnistTrainY);display(trainAccuray);fn = strcat (root, '\MnistTestX.mat');load (fn);fn = strcat (root, '\MnistTestY.mat');load (fn);testAccuray  = ComputeInnerProductAccuracy(weights, b, MnistTestX, MnistTestY);display(testAccuray);

这是准确率计算函数:

function [acc]=ComputeInnerProductAccuracy(weights, b, X, Y)n = size(X, 1);minmat = zeros (60000, 2);endmat = zeros (60000, 10);m = size(X);a=0;for i=1 : n    for j=1 : 10        endmat(i,j)=sum((X(i,:)-(weights(j,:))).^2,2);    end    [minmat(i,1) ,minmat(i,2)]= min(endmat(i,:));    if minmat(i,2)== Y(i)        a=a+1;    endendacc=(a*100)/60000;end

回答:

你的代码基本上是正确的,尽管效率很低。我不会花时间去提高它的效率,因为有许多地方需要改进。相反,我将专注于代码中的问题。代码有两个问题。首先是当你查找哪个数字的距离最小时:

[minmat(i,1) ,minmat(i,2)]= min(endmat(i,:));

请注意,min函数的第二个输出产生的是最小值的位置,从索引1开始Y中的类别值应包含0到9,但你的min函数的输出索引是从1到10。输出最小值的索引和对应的类别值之间相差1,这可能是你准确率如此低的原因。

因此,你需要在检查最小标签是否确实是真实标签之前,从minmat(i, 2)中减去1…或者你可以简单地在检查时将Y(i)加1:

[minmat(i,1) ,minmat(i,2)]= min(endmat(i,:));if minmat(i,2)== Y(i)+1 % 更改    a=a+1;end

第二个错误是“内积”函数(实际上你在计算欧几里得距离….但我们先不讨论这个问题)假设总是有60000个输入,但你的测试集没有这么多输入。这在训练数据上运行良好,但会报告错误的测试数据准确率。确保将函数中的所有60000实例更改为n。你已经在代码中创建了这个变量,它决定了有多少输入。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注