Matlab调试 – 初学者级别

我在Matlab方面完全是个初学者,正在尝试编写一些机器学习算法。如果有人能帮助我调试这段代码,我将不胜感激。

function y = KNNpredict(trX,trY,K,X)   % trX是NxD,trY是Nx1,K是1x1,X是1xD   % 我们返回一个单一值'y',它是预测的类别% TODO: 编写这个函数% int[] distance = new int[N];distances = zeroes(N, 1);examples = zeroes(K, D+2);i = 0;% for(every row in trX) { // 取一个例子for row=1:N,  examples(row,:) = trX(row,:); %sum = 0.0; %for(every col in this example) { // 取这个例子的每个特征 for col=1:D,     % diff = 计算这些点之间的平方差 - (trX[row][col]-X[col])^2    diff =(trX(row,col)-X(col))^2;    sum += diff; end % for distances(row) = sqrt(sum); examples(i:D+1) = distances(row); examples(i:D+2) = trY(row:1);end % for% 根据计算出的距离对例子进行排序sortrows(examples, D+1);% for(int i = 0; i < K; K++) {% 这些是最接近的邻居pos = 0;neg = 0;res = 0;for row=1:K,    if(examples(row,D+2 == -1))        neg = neg + 1;    else        pos = pos + 1;    %disp(distances(row));    endend % forif(pos > neg)    y = 1;    return;else    y = -1;    return;endendend

非常感谢


回答:

在Matlab中处理矩阵时,通常最好避免过多的循环,而是尽可能使用向量化操作。这通常会产生更快、更简洁的代码。

在你的情况下,k-最近邻算法足够简单,并且可以很好地向量化。考虑以下实现:

function y = KNNpredict(trX, trY, K, x)    %# 实例x与每个训练实例之间的欧几里得距离    dist = sqrt( sum( bsxfun(@minus, trX, x).^2 , 2) );    %# 从小到大排序距离的索引    [~,ord] = sort(dist, 'ascend');    %# 获取K个最近邻居的标签    kTrY = trY( ord(1:min(K,end)) );    %# 多数类投票    y = mode(kTrY);end

这里有一个使用Fisher-Iris数据集测试的例子:

%# 加载数据集(数据 + 标签)load fisheririsX = meas;Y = grp2idx(species);%# 将数据划分为训练/测试集c = cvpartition(Y, 'holdout',1/3);trX = X(c.training,:);trY = Y(c.training);tsX = X(c.test,:);tsY = Y(c.test);%# 预测K = 10;pred = zeros(c.TestSize,1);for i=1:c.TestSize    pred(i) = KNNpredict(trX, trY, K, tsX(i,:));end%# 验证C = confusionmat(tsY, pred)

KNN预测的混淆矩阵,K=10时:

C =    17     0     0     0    16     0     0     1    16

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注