scikit-learn在多类分类中默认使用One-Vs-Rest吗?

我在处理一个多类问题(4个类),并尝试使用Python中的scikit-learn来解决这个问题。

我看到有三种选择:

  1. 我简单地实例化一个分类器,然后用训练数据进行拟合,并用测试数据进行评估;

    classifier = svm.LinearSVC(random_state=123)classifier.fit(Xtrain, ytrain)classifier.score(Xtest, ytest)
  2. 我将实例化的分类器“封装”在OneVsRest对象中,生成一个新的分类器,用于训练和测试;

    classifier = OneVsRestClassifier(svm.LinearSVC(random_state=123))classifier.fit(Xtrain, ytrain)classifier.score(Xtest, ytest)
  3. 我将实例化的分类器“封装”在OneVsOne对象中,生成一个新的分类器,用于训练和测试。

    classifier = OneVsOneClassifier(svm.LinearSVC(random_state=123))classifier.fit(Xtrain, ytrain)classifier.score(Xtest, ytest)

我理解OneVsRest和OneVsOne之间的区别,但我无法理解在第一种情况下我没有明确选择这两种选项中的任何一种时我在做什么。在这种情况下,scikit-learn会做什么?它是否隐式地使用OneVsRest?

对此问题的任何澄清都将不胜感激。

最好的祝愿,MR

编辑:为了澄清,我对SVM的情况并不特别感兴趣。例如,RandomForest的情况如何?


回答:

更新的回答:正如评论和编辑中所澄清的,问题更多是关于sklearn的一般设置,而不是下文解释的LinearSVC的具体情况。

这里的主要区别在于,您可以使用的某些分类器具有“内置的多类分类支持”,即该算法默认情况下可以区分两个以上的类。一个例子是随机森林,或具有多个输出节点的多层感知器(MLP)。

在这些情况下,根本不需要OneVs对象,因为您已经在解决任务。事实上,使用这种策略甚至可能会降低您的性能,因为您通过仅让算法在单个二元实例之间进行选择,从而“隐藏”了潜在的相关性。

另一方面,像SVCLinearSVC这样的算法仅支持二元分类。因此,为了扩展这些(表现良好的)算法类别,我们必须依赖于从初始的多类分类任务减少到二元分类任务。

据我所知,最完整的概述可以在这里找到:这里:如果你向下滚动一点,你可以看到哪些算法是天生多类的,或者默认使用其中一种策略。
请注意,所有列出的OVO算法实际上现在默认采用OVR策略!这似乎是关于这一点的信息有些过时了。

初始回答

这个问题可以通过查看相关的scikit-learn文档轻松回答。
一般来说,Stackoverflow的期望是您至少已经进行了某种形式的自我研究,因此请考虑先查看现有文档。

multi_class : 字符串,‘ovr’或‘crammer_singer’(默认=‘ovr’)

如果y包含两个以上的类,则确定多类策略。"ovr"训练n_classes个一对剩余分类器,而"crammer_singer"优化所有类上的联合目标。虽然从理论上讲crammer_singer是一致的,但它很少在实践中使用,因为它很少能带来更好的准确性,并且计算成本更高。如果选择"crammer_singer",则会忽略选项loss、penalty和dual。

所以,很明显,它使用的是一对剩余(one-vs-rest)。

顺便说一下,“常规”的SVC也是如此。

Related Posts

在使用k近邻算法时,有没有办法获取被使用的“邻居”?

我想找到一种方法来确定在我的knn算法中实际使用了哪些…

Theano在Google Colab上无法启用GPU支持

我在尝试使用Theano库训练一个模型。由于我的电脑内…

准确性评分似乎有误

这里是代码: from sklearn.metrics…

Keras Functional API: “错误检查输入时:期望input_1具有4个维度,但得到形状为(X, Y)的数组”

我在尝试使用Keras的fit_generator来训…

如何使用sklearn.datasets.make_classification在指定范围内生成合成数据?

我想为分类问题创建合成数据。我使用了sklearn.d…

如何处理预测时不在训练集中的标签

已关闭。 此问题与编程或软件开发无关。目前不接受回答。…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注