我能从glmnet的输出中获得glm对象吗?

我使用glmnet得到了一个L2归一化的逻辑回归模型(即岭回归模型) –

ridge.model <- glmnet(x, y, alpha = 0, family = "binomial", lambda = bestlam)

现在,我想使用10折交叉验证来找到它的测试错误率。这可以使用cv.glm来完成,如下所示 –

fit_10CV<- glm(good ~ ., family = binomial, data = winedata)fit_10CV.cv.err =cv.glm(winedata ,fit_10CV, cost1, K = 10) # 10折交叉验证

但它需要一个glm对象(在本例中为fit_10CV)。

但是,glmnet的输出是一个glmnet对象,而cv.glm无法接受。我感觉自己离目标很近却又很远,因为glmnet给了我所需的逻辑回归模型,但它没有以我可以直接输入到cv.glm中来获取10折交叉验证测试错误率的形式(即作为一个glm对象)提供给我。

任何帮助都将不胜感激!


回答:

你可以使用以下两种方法来实现这一点:

url="https://raw.githubusercontent.com/stedy/Machine-Learning-with-R-datasets/master/winequality-white.csv"winedata = read.csv(url)winedata$good = factor(ifelse(winedata$quality>6,1,0))winedata$quality = NULL

首先我们运行cv.glmnet,它只保留平均标准误:

library(caret)library(glmnet)x = model.matrix(good ~ ., family = binomial, data = winedata)cv_glmnet = cv.glmnet(x, winedata$good, family = "binomial", type.measure = "class",alpha=0,nfolds = 10)

我们收集测试的lambda值,并使用caret来进行交叉验证:

tr = trainControl(method="cv",number=10)trGrid = data.frame(lambda=cv_glmnet$lambda,alpha=0)cv_caret = train(good ~ .,data=winedata,trControl = tr,tuneGrid=trGrid,family="binomial",method="glmnet")

在caret中,他们测量准确率,而1 – 准确率就是你从cv.glmnet中得到的误分类错误率。我们将它们放在一起,你可以看到它们非常相似

library(ggplot2)df = rbind(data.frame(lambda=cv_glmnet$lambda,mean_error=cv_glmnet$cvm,method="cv.glmnet"),data.frame(lambda=cv_caret$results$lambda,mean_error=1-cv_caret$results$Accuracy,method="cv.caret"))ggplot(df,aes(x=log(lambda),y=mean_error,col=method)) + geom_point() +facet_wrap(~method) + theme_bw()

enter image description here

你可以从最佳模型的各个重抽样中获取错误,如下所示:

cv_caret$resample    Accuracy     Kappa Resample1  0.7975460 0.1987720   Fold092  0.8163265 0.3127844   Fold013  0.7918367 0.1631505   Fold024  0.8057260 0.2311448   Fold055  0.8061224 0.2777347   Fold046  0.7979592 0.1989498   Fold077  0.8081633 0.2759683   Fold108  0.8183673 0.3053742   Fold089  0.8102041 0.2474983   Fold0310 0.7979592 0.1989498   Fold06

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注