scikit多标签分类:ValueError: 输入形状错误

我认为使用SGDClassifier()并设置loss='log'可以支持多标签分类,而不需要使用OneVsRestClassifier查看这里

现在,我的数据集非常大,我使用了HashingVectorizer并将其结果作为输入传递给SGDClassifier。我的目标有42048个特征。

当我按以下方式运行时:

clf.partial_fit(X_train_batch, y)

我得到:ValueError: bad input shape (300000, 42048)

我还使用了如下参数,但仍然是相同的问题。

clf.partial_fit(X_train_batch, y, classes=np.arange(42048))

SGDClassifier的文档中,它说明y : numpy array of shape [n_samples]


回答:

不,SGDClassifier并不进行多标签分类——它进行的是多类分类,这是一个不同的问题,尽管两者都通过一对多的问题简化来解决。

然后,无论是SGD还是OneVsRestClassifier.fit都不会接受y的稀疏矩阵。前者需要一个标签数组,正如你已经发现的。后者对于多标签目的,需要一个标签列表的列表,例如

y = [[1], [2, 3], [1, 3]]

表示X[0]有标签1,X[1]有标签{2,3}X[2]有标签{1,3}

Related Posts

Scikit的fit_transform、ColumnTransformer和OneHotEncoder的目的是编码分类数据,为什么它们也用于数值数据?

我在寻找机器学习的例子来学习和理解时,无意中发现了这个…

训练一个好模型时,我应该标准化哪些数据?

已关闭。 此问题不符合 Stack Overflow …

自定义损失函数中随步骤增加降低权重

我想随着步骤的增加改变施加在损失上的权重。为了实现这一…

在Bert分类中如何获取预测准确率

我在我的聊_bot项目中使用Bert分类器。我对传入的…

超参数调优的顺序

已关闭。 此问题不符合 Stack Overflow …

Keras Sequential预测总是返回相同的结果

这是一个我用来对图片进行分类(跑鞋、铅笔和书)的算法。…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注