scikit learn决策树模型评估

这里是相关的代码和文档,我想知道在没有明确指定score参数的情况下,默认的cross_val_score输出的数组表示的是精确度、AUC还是其他一些指标?

使用Python 2.7和miniconda解释器。

http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

>>> from sklearn.datasets import load_iris>>> from sklearn.cross_validation import cross_val_score>>> from sklearn.tree import DecisionTreeClassifier>>> clf = DecisionTreeClassifier(random_state=0)>>> iris = load_iris()>>> cross_val_score(clf, iris.data, iris.target, cv=10)...                             ...array([ 1.     ,  0.93...,  0.86...,  0.93...,  0.93...,        0.93...,  0.93...,  1.     ,  0.93...,  1.      ])

此致,林


回答:

用户指南中得知:

默认情况下,每次交叉验证迭代计算的分数是估计器的score方法。可以通过使用scoring参数来更改这一点:

从DecisionTreeClassifier的文档中得知:

返回给定测试数据和标签的平均准确率。在多标签分类中,这是子集准确率,这是一个严格的指标,因为你需要每个样本的每个标签集都被正确预测。

不要被“平均准确率”所迷惑,这只是计算准确率的常规方式。请查看源码

    from .metrics import accuracy_score    return accuracy_score(y, self.predict(X), sample_weight=sample_weight)

现在是metrics.accuracy_score源码

def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):    ...    # 为每种可能的表示计算准确率    y_type, y_true, y_pred = _check_targets(y_true, y_pred)    if y_type.startswith('multilabel'):        differing_labels = count_nonzero(y_true - y_pred, axis=1)        score = differing_labels == 0    else:        score = y_true == y_pred    return _weighted_sum(score, sample_weight, normalize)

如果你仍然不相信:

def _weighted_sum(sample_score, sample_weight, normalize=False):    if normalize:        return np.average(sample_score, weights=sample_weight)    elif sample_weight is not None:        return np.dot(sample_score, sample_weight)    else:        return sample_score.sum()

注意:对于accuracy_score,normalize参数默认设置为True,因此它只是返回布尔numpy数组的np.average,因此它只是正确预测的平均数。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注