使用R语言中的pROC库无法正确获取多类ROC曲线

我的预测列包含了垃圾邮件、非垃圾邮件和无法定义的类别。我使用了集成方法中的堆叠方法来进行预测。我能够达到大约77%的准确率,我能够绘制ROC曲线,但我认为它并不正确。

集成技术代码:

# 生成一级数据集用于训练集成元学习器
predDF <- data.frame(dataTest.pred, NB_Predictions, RF_Predictions,SVM_Predictions, spam = validation$spam, stringsAsFactors = F)
# 训练集成
# 定义训练控制
set.seed(1958)
# 训练集成模型
modelStack <- caret::train(spam ~ ., data = predDF, method = "rf")

这是我绘制ROC曲线的代码:

#ROC曲线
pre<-predict(modelStack, testPredLevelOne,type='prob')
# AUC测量
modelroc = mean(  c(as.numeric(multiclass.roc(testPredLevelOne$spam, pre[,1])$auc),
    as.numeric(multiclass.roc(testPredLevelOne$spam, pre[,2])$auc),
    as.numeric(multiclass.roc(testPredLevelOne$spam, pre[,3])$auc)
  ))

我认为AUC测量值也不正确。

> modelroc
[1] 0.500903

这是我尝试的另一种绘制ROC曲线的方法。我认为这种方法也不起作用。

#方法2 ROC
predictions<-as.numeric(predict(modelStack, testPredLevelOne))
roc.multi<-multiclass.roc(testPredLevelOne$spam,predictions)
auc(roc.multi)
rs <- roc.multi[['rocs']]
plot.roc(rs[[1]])
sapply(2:length(rs),function(i) lines.roc(rs[[i]],col=i))

这是绘制的图表:

enter image description here

有人能帮帮我吗?非常感谢。


回答:

pROC库中对类别间的平均值是直接计算的。因此,您只需要运行一次multiclass.roc,并提供一个预测向量。通常您应该使用type="response"而不是probs,尽管这可能因您的模型可用的predict函数而异:

pre<-predict(modelStack, testPredLevelOne, type='response')

之后,pROC会为您计算平均值,因此您可以直接获取平均AUC:

multiclass.roc(testPredLevelOne$spam, pre)$auc

请注意,在您的案例中,这将是三个AUC的平均值:垃圾邮件对非垃圾邮件、垃圾邮件对无法定义和非垃圾邮件对无法定义。这可能与您计算的准确率值不一致。

关于绘图,您从caret的predict.train方法中获得了预测类别。通常,为了构建ROC曲线,您需要一个数值的、定量的测量结果。定性测量结果会导致ROC曲线上只有一个点,这通常不是最优的。不幸的是,caret中的type="probs"选项返回的是三个概率向量,这不是pROC支持的格式。

另一种方法是手动进行计算,选择您想要测试的正确概率列和级别:

pre<-predict(modelStack, testPredLevelOne,type='prob')
roc(testPredLevelOne$spam, pre[,X], levels = c("not spam", "spam"), plot = TRUE)
roc(testPredLevelOne$spam, pre[,X], levels = c("undefined", " spam"), plot = TRUE, add = TRUE)

最终,您应该仔细审查多类ROC分析的相关性。ROC最初是为二元分类设计的,根据我的经验,各种现有的多类扩展的相关性有些值得怀疑。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注