使用NLTK、scikit-learn和OneVsRestClassifier开启多标签分类

免责声明:我对AI、Python、NLTK和scikit-learn还比较新手。

我正在尝试训练一个分类器,将一组文档分类到一组标签中。

我使用NLTK的包装器来与scikit-learn的OneVsRestClassifier进行交互。

training_set = [    [{"car": True, ...}, "Label 1"],    [{"car": False, ...}, "Label 2"],    ...    [{"car": False, ...}, "Label 1"],]ovr = SklearnClassifier(OneVsRestClassifier(MultinomialNB()))ovr.train(training_set)

这在多类分类中运行良好,分类器尝试将文档仅分类到一个标签。准确率还可以,但我希望分类器能够为文档分配0个、1个或多个标签。我该怎么做呢?

遗憾的是,我无法通过初始化分类器来告诉它成为一个多标签分类器,文档中说:

这种策略也可以用于多标签学习,其中分类器用于预测多个标签的实例,通过在第i个样本具有第j个标签时,单元格[i, j]为1,否则为0的2维矩阵上进行拟合。

由于我不熟悉这种语言,这对我来说不是很清楚。我感觉我需要以某种方式调整我的训练集,使分类器能够理解我想让它对我的数据进行多标签分类?如果是的,如何做呢?

我尝试以数组的形式提供标签,像这样:

training_set = [    [{"car": True, ...}, ["Label 1"]],    [{"car": False, ...}, ["Label 2"]],    ...    [{"car": False, ...}, ["Label 1"]],]

这没有按预期工作,并引发了以下错误:

DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().  y = column_or_1d(y, warn=True)One-vs-rest accuracy percent: 0.0

回答:

我通过去掉NLTK到scikit-learn的适配器,并导入一个NLTK模块来帮助我将数据结构转换为可以输入到scikit-learn的OneVsRestClassifier中,从而解决了这个问题。

from nltk import compatfrom sklearn.feature_extraction import DictVectorizerfrom sklearn.naive_bayes import MultinomialNBfrom sklearn.multiclass import OneVsRestClassifier_vectorizer = DictVectorizer(dtype=float, sparse=True)def prepare_scikit_x_and_y(labeled_featuresets):    X, y = list(compat.izip(*labeled_featuresets))    X = _vectorizer.fit_transform(X)    set_of_labels = []    for label in y:        set_of_labels.append(set(label))    y = self.mlb.fit_transform(set_of_labels)    return X, ydef train_classifier(labeled_featuresets):    X, y = prepare_scikit_x_and_y(labeled_featuresets)    classifier.fit(X, y)training_set = [    [{"car": True, ...}, ["Label 1"]],    [{"car": False, ...}, ["Label 2"]],    ...    [{"car": False, ...}, ["Label 1"]],]ovr = OneVsRestClassifier(MultinomialNB())ovr.train(training_set)

开心豆

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注