使用Scikit-learn训练SVM(支持向量机)分类器

我想使用以下代码在Scikit-learn中训练不同的分类器,以解决多标签分类问题:

names = [    "Nearest Neighbors",    "Linear SVM", "RBF SVM", "Gaussian Process",    "Decision Tree", "Random Forest", "Neural Net", "AdaBoost",    "Naive Bayes", "QDA"]classifiers = [    KNeighborsClassifier(3),    SVC(C=0.025),    SVC(gamma=2, C=1),    GaussianProcessClassifier(1.0 * RBF(1.0)),    DecisionTreeClassifier(max_depth=5),    RandomForestClassifier(max_depth=5),    MLPClassifier(alpha=0.5),    AdaBoostClassifier(),    GaussianNB(),    QuadraticDiscriminantAnalysis()]for name, clf in izip(names, classifiers):    clf.fit(X_train, Y_train)    score = clf.score(X_train, Y_test)    print name, score

KNeighbors分类器运行正常,但当我尝试使用SVM分类器时,抛出了以下异常:

Traceback (most recent call last):  File "/Users/mac/PycharmProjects/GraphLstm/classifier.py", line 87, in <module>    clf.fit(X_train, Y_train)  File "/Library/Python/2.7/site-packages/sklearn/svm/base.py", line 151, in fit    X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')  File "/Library/Python/2.7/site-packages/sklearn/utils/validation.py", line 526, in check_X_y    y = column_or_1d(y, warn=True)  File "/Library/Python/2.7/site-packages/sklearn/utils/validation.py", line 562, in column_or_1d    raise ValueError("bad input shape {0}".format(shape))ValueError: bad input shape (9280, 39)

这是什么原因?如何解决这个问题?

编辑: 正如@隐藏人名所评论的,以下分类器仅适用于多标签分类

sklearn.tree.DecisionTreeClassifiersklearn.tree.ExtraTreeClassifiersklearn.ensemble.ExtraTreesClassifiersklearn.neighbors.KNeighborsClassifiersklearn.neural_network.MLPClassifiersklearn.neighbors.RadiusNeighborsClassifiersklearn.ensemble.RandomForestClassifiersklearn.linear_model.RidgeClassifierCV

回答:

由于这是一个多标签分类问题,并非所有Scikit-learn中的估计器都能天然处理它们。文档提供了支持多标签的估计器列表,如各种基于树的估计器或其他:

sklearn.tree.DecisionTreeClassifiersklearn.tree.ExtraTreeClassifiersklearn.ensemble.ExtraTreesClassifiersklearn.neighbors.KNeighborsClassifier......

然而,像一对多这样的策略可以用来训练不直接支持多标签的所需估计器。Scikit-learn估计器OneVsRestClassifier就是为此而设计的。

有关更多详细信息,请参见此处的文档

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注