I want to train several different Scikit-learn classifiers on a multi-label classification problem, using the following code:
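from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

names = [
    "Nearest Neighbors", "Linear SVM", "RBF SVM", "Gaussian Process",
    "Decision Tree", "Random Forest", "Neural Net", "AdaBoost",
    "Naive Bayes", "QDA",
]

classifiers = [
    KNeighborsClassifier(3),
    SVC(C=0.025),  # note: SVC defaults to the RBF kernel; pass kernel="linear" for a linear SVM
    SVC(gamma=2, C=1),
    GaussianProcessClassifier(1.0 * RBF(1.0)),
    DecisionTreeClassifier(max_depth=5),
    RandomForestClassifier(max_depth=5),
    MLPClassifier(alpha=0.5),
    AdaBoostClassifier(),
    GaussianNB(),
    QuadraticDiscriminantAnalysis(),
]

# X_train, Y_train, X_test, Y_test are assumed to be defined earlier in the script
for name, clf in zip(names, classifiers):
    clf.fit(X_train, Y_train)
    score = clf.score(X_test, Y_test)  # score on the test features, not X_train
    print(name, score)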
The KNeighbors classifier runs fine, but when I try the SVM classifiers, the following exception is raised:
Traceback (most recent call last):
  File "/Users/mac/PycharmProjects/GraphLstm/classifier.py", line 87, in <module>
    clf.fit(X_train, Y_train)
  File "/Library/Python/2.7/site-packages/sklearn/svm/base.py", line 151, in fit
    X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
  File "/Library/Python/2.7/site-packages/sklearn/utils/validation.py", line 526, in check_X_y
    y = column_or_1d(y, warn=True)
  File "/Library/Python/2.7/site-packages/sklearn/utils/validation.py", line 562, in column_or_1d
    raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (9280, 39)
What causes this, and how can I fix it?
Edit: As @[name hidden] pointed out in a comment, only the following classifiers support multi-label classification:
sklearn.tree.DecisionTreeClassifier
sklearn.tree.ExtraTreeClassifier
sklearn.ensemble.ExtraTreesClassifier
sklearn.neighbors.KNeighborsClassifier
sklearn.neural_network.MLPClassifier
sklearn.neighbors.RadiusNeighborsClassifier
sklearn.ensemble.RandomForestClassifier
sklearn.linear_model.RidgeClassifierCV
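Answer:
Because this is a multi-label classification problem, not every Scikit-learn estimator can handle it natively. SVC expects a one-dimensional target, so the (9280, 39) label matrix fails the column_or_1d check shown in the traceback. The documentation lists the estimators that support multi-label targets, including the tree-based estimators and a few others: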
sklearn.tree.DecisionTreeClassifier
sklearn.tree.ExtraTreeClassifier
sklearn.ensemble.ExtraTreesClassifier
sklearn.neighbors.KNeighborsClassifier
...
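To illustrate the difference, here is a minimal sketch that fits one estimator from the list and an SVC on the same multi-label data. The data comes from make_multilabel_classification and is only a synthetic stand-in for the asker's dataset. All variable names are assumptions for illustration.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the real data: Y is a 2-D binary indicator matrix
# with one column per label (here 200 samples x 5 labels).
X, Y = make_multilabel_classification(
    n_samples=200, n_features=10, n_classes=5, random_state=0
)

# A tree-based estimator accepts the 2-D label matrix directly.
tree = DecisionTreeClassifier(max_depth=5, random_state=0).fit(X, Y)
tree_pred_shape = tree.predict(X).shape  # one predicted column per label

# SVC expects a 1-D target, so the same call raises ValueError.
try:
    SVC(gamma=2, C=1).fit(X, Y)
    svc_error = None
except ValueError as exc:
    svc_error = exc

print(tree_pred_shape)
print(type(svc_error).__name__)
```

The exact wording of the error message depends on the installed Scikit-learn version, but the exception type is ValueError in both cases.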
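However, a strategy such as one-vs-rest can be used to train an estimator that does not support multi-label targets directly. Scikit-learn's OneVsRestClassifier is designed for exactly this purpose: it fits one binary classifier per label.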
For more details, see the Scikit-learn documentation.
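The one-vs-rest approach described above could look roughly like the sketch below. It wraps an SVC in OneVsRestClassifier and trains it on synthetic multi-label data. The dataset, the train/test split, and the SVC hyperparameters are assumptions chosen for illustration, not the asker's actual setup.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC

# Synthetic multi-label data standing in for the real X and Y.
X, Y = make_multilabel_classification(
    n_samples=300, n_features=10, n_classes=5, random_state=0
)
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.25, random_state=0
)

# OneVsRestClassifier fits one independent binary SVC per label column,
# so the wrapped SVC only ever sees a 1-D target.
clf = OneVsRestClassifier(SVC(kernel="linear", C=0.025))
clf.fit(X_train, Y_train)

Y_pred = clf.predict(X_test)       # indicator matrix, one column per label
score = clf.score(X_test, Y_test)  # subset accuracy: all labels of a sample must match
print(Y_pred.shape, score)
```

In the original loop, the same wrapper could be applied only to the entries that lack native multi-label support, for example by replacing SVC(gamma=2, C=1) with OneVsRestClassifier(SVC(gamma=2, C=1)).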