在MATLAB中使用神经网络进行分类:获取元素属于第i类的概率

我想用MATLAB解决一个分类问题。我的数据集包含3个类别和1900个样本。每个样本由10个特征定义,我有900个样本属于类别’1’,500个属于类别’2’,500个属于类别’3’。

我尝试使用MATLAB中的标准patternnet工具来训练神经网络。我进行了不同数量的神经元测试,从1到100,但分类性能始终不佳。

因此,我查看了混淆矩阵,并注意到问题在于分类器将类别’2’和’3’混淆了。接下来我尝试创建两个神经网络:

  1. 第一个神经网络是一个2类分类器,包含类别’1’和类别’23’(类别’2’和’3’的结合)。这个初步分类对我来说准确率较高(约90%)
  2. 第二个神经网络也是一个2类分类器,仅接受类别’2’和’3’的元素作为输入。问题是这个第二个神经网络的准确率相当低,约55%。

因此,我再次遇到了提高分类准确率的困难。我想进行一些测试,看看是否能提高准确率。我的想法是查看每个元素属于特定类别的概率。接下来我想做以下之一:

  1. 尝试更改确定样本类别的阈值。例如,如果所有概率大于70%的元素确实属于类别’3’,而概率在50%到70%之间的元素通常属于类别’2’(我只是编造一些数字来解释我想测试的内容)
  2. 为难以分类的样本创建一个新的类别’4’。例如,如果概率大于70%的元素确实属于类别’3’,我将把概率小于70%的元素归为类别’4’。如果这有效,我可能会有一些“未知类别’4’”的元素,但分类为’2’或’3’的元素将具有很高的准确度

所以首先我想知道是否可以获取每个元素属于特定类别的概率,其次是否有MATLAB中的标准方法来实现我想要进行的两个测试之一。(当然,如果有人有更好的想法,我很乐意测试)抱歉描述过长,但希望我至少解释清楚了我的问题。


回答:

@[隐藏人名]。为了提供一个合适的示例(包括代码等),我必须写一个完整的答案。使用SVM进行多类分类的最简单方法是使用LibSVM。LibSVM是一个免费的SVM库(你可以从这里下载),它也可以在MATLAB环境中安装和使用。解压文件后,你会在matlab文件夹中找到安装指南和所有相关内容。

你想要做的是One-vs-All SVM方法,即你训练N个SVM(N是类别的数量),每个SVM被训练来将一个给定的类别i与所有其他类别分开(第i类将被标记为正类,所有非i类将被标记为负类)。假设TrainingSetTrainingLabelsValidationSetValidationLabels是你的数据集(它们的名称相当直观),numLabels是标签的数量(在你的情况下是3)。

你可以按以下方式训练这些SVM:

for k=1:numLabels    % k-th class positive, all the other classes are negative    LabelsRecoded(TrainingLabels==k)=1;    LabelsRecoded(TrainingLabels~=k)=-1;    model{k} = svmtrain(LabelsRecoded, TrainingSet, '-c 1 -b 1 -t 0');end

在这个代码中,'-c 1 -b 1 -t 0'是LibSVM的参数:c是正则化项(设置为1),-b 1表示你想要收集输出概率(也称为决策值),-t 0表示你使用线性核。更多信息可以在LibSVM包中的自述文件中找到。另一方面,model是一个单元数组,其中第k个元素包含用于将第k个类别与所有其他类别分开的SVM结构。

预测阶段的结构如下:

LabelsRecoded=[]; % get rid of the results stored previously in the training phasefor k=1:numLabels    # same as before, but with validation labels    LabelsRecoded(ValidationLabels==k)=1;    LabelsRecoded(ValidationLabels~=k)=-1;    [~,~,p] = svmpredict(LabelsRecoded, ValidationSet, model{k}, '-b 1');    prob(:,k) = p(:,model{k}.Label==1);end

prob中,你将拥有3列(3是类别的数量),包含第k个类别为正的概率(注意model{k}.Label==1)。现在你可以根据最大概率值来收集预测标签,如下所示:

[~,PredictedLabels] = max(prob,[],2);

现在你同时拥有了预测标签和验证标签,可以根据标准公式评估准确率。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注