更改随机森林分类器的阈值

我需要开发一个模型,尽量避免或减少假阴性值。为此,我绘制了召回率-精确度曲线,并确定阈值应设为0.11。

我的问题是,如何在模型训练时定义阈值?在评估后再定义阈值没有意义,因为它不会反映在新数据上。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=101)rfc_model = RandomForestClassifier(random_state=101)rfc_model.fit(X_train, y_train)rfc_preds = rfc_model.predict(X_test)recall_precision_vals = []for val in np.linspace(0, 1, 101):    predicted_proba = rfc_model.predict_proba(X_test)    predicted = (predicted_proba[:, 1] >= val).astype('int')        recall_sc = recall_score(y_test, predicted)    precis_sc = precision_score(y_test, predicted)    recall_precision_vals.append({        'Threshold': val,        'Recall val': recall_sc,        'Precis val': precis_sc    })recall_prec_df = pd.DataFrame(recall_precision_vals)

有什么建议吗?


回答:

如何在模型训练时定义阈值?

在模型训练过程中根本没有阈值;随机森林是一种概率分类器,它只输出类别概率。需要阈值的“硬”类别(即0/1)在模型训练的任何阶段既不产生也不使用——仅在预测时使用,而且仅在我们确实需要硬分类的情况下才使用(并非总是如此)。请参阅预测类别还是类别概率?以获取更多详情。

实际上,scikit-learn实现的RF根本不使用阈值,即使是硬类别预测;仔细阅读predict方法的文档

预测的类别是树中平均概率估计最高的那个

简单来说,这意味着实际的RF输出是[p0, p1](假设是二元分类),predict方法简单地返回值最高的类别,即如果p0 > p1则返回0,否则返回1。

假设你实际上想要做的是如果p1大于某个小于0.5的阈值时返回1,你必须放弃predict,改用predict_proba,然后操作这些返回的概率以得到你想要的结果。这里有一个使用虚拟数据的示例:

from sklearn.ensemble import RandomForestClassifierfrom sklearn.datasets import make_classificationX, y = make_classification(n_samples=1000, n_features=4,                          n_informative=2, n_redundant=0,                           n_classes=2, random_state=0, shuffle=False)clf = RandomForestClassifier(n_estimators=100, max_depth=2,                            random_state=0)clf.fit(X, y)

在这里,简单地对X的第一个元素使用predict将返回0:

clf.predict(X)[0] # 0

因为

clf.predict_proba(X)[0]# array([0.85266881, 0.14733119])

p0 > p1

为了得到你想要的结果(即在这里返回类别1,因为p1 > threshold,阈值为0.11),你需要做的是:

prob_preds = clf.predict_proba(X)threshold = 0.11 # 定义阈值在这里preds = [1 if prob_preds[i][1]> threshold else 0 for i in range(len(prob_preds))]

之后,很容易看出现在对于第一个预测样本我们有:

preds[0]# 1

因为如上所示,对于这个样本我们有p1 = 0.14733119 > threshold

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注