在scikit-learn
中运行交叉验证时,所有分类器都会有一个工厂函数score()
,我可以轻松地检查分类器的准确性,例如来自http://scikit-learn.org/stable/modules/cross_validation.html
>>> import numpy as np>>> from sklearn import cross_validation>>> from sklearn import datasets>>> from sklearn import svm>>> iris = datasets.load_iris()>>> iris.data.shape, iris.target.shape((150, 4), (150,))>>> X_train, X_test, y_train, y_test = cross_validation.train_test_split(... iris.data, iris.target, test_size=0.4, random_state=0)>>> X_train.shape, y_train.shape((90, 4), (90,))>>> X_test.shape, y_test.shape((60, 4), (60,))>>> clf = svm.SVC(kernel='linear', C=1).fit(X_train, y_train)>>> clf.score(X_test, y_test) 0.96...
在scikit-learn
的GitHub仓库中深入研究后,我仍然无法找出clf.score()
函数的位置。
这个链接没有包含score()
,https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/svm/classes.py
sklearn
分类器的score()
函数位于哪里?
我可以轻松实现自己的评分函数,但我的目标是构建我的库,使其与sklearn
分类器保持一致,而不是自己想出评分函数 =)
回答:
scikit-learn分类器的默认score()
方法是准确率评分,并且在ClassifierMixin
类中定义。这个混入类是scikit-learn内置分类器的大多数(全部?)的父类。
如果你在编写自己的分类器,我建议你也从这个混入类和BaseEstimator
继承,这样你的模型就会自动获得评分和其他功能。