使用VIP包在parsnip模型上计算重要性度量

我正在尝试使用vi_firm()在parsnip中创建的逻辑回归模型上计算特征重要性。我将使用iris数据集,并尝试预测一个观测是否属于setosa种类。

iris1 <- iris %>%  mutate(class  = case_when(Species == 'setosa' ~ 'setosa',                            TRUE ~ 'other'))iris1$class = as.factor(iris1$class)#set up logistic regression modeliris.lr = logistic_reg(  mode="classification",  penalty=NULL,  mixture=NULL) %>%  set_engine("glmnet")iris.fit = iris.lr %>%  fit(class ~. , data = iris1)library(vip)vip::vi_firm(iris.fit, feature_names = features, train = iris1, type = 'classification')

这会返回

错误:您是否应该使用new_data而不是newdata

我也在尝试使用相关的pdp包中的partial函数生成部分依赖图,我得到了同样的错误。


回答:

对于“glmnet”对象,正确的参数应该是s,而不是lambda,以与coef.glmnet保持一致(然而,目前用vi()调用它会因与scale参数的部分匹配而产生错误——我将在本周末推出一个修复;https://github.com/koalaverse/vip/issues/103)。此外,从版本0.2.2开始,vi_model应该可以直接与model_fit对象一起工作。所以这里正确的调用应该是:

> vi_model(iris_fit, s = iris_fit$fit$lambda[10]). ## A tibble: 4 x 3  Variable     Importance Sign   <chr>             <dbl> <chr>1 Sepal.Length      0     NEG  2 Sepal.Width       0     NEG  3 Petal.Length     -0.721 NEG  4 Petal.Width       0     NEG 

关于vi_firm()pdp::partial(),最简单的方法是创建您自己的预测包装器。每个函数的文档中应该有足够的细节,我们即将发表的论文中也有更多示例(https://github.com/koalaverse/vip/blob/master/rjournal/RJwrapper.pdf),但这里是一个基本的示例:

> # Data matrix (features only)> X <- data.matrix(subset(iris1, select = -class))> > # Prediction wrapper for partial dependence> pfun <- function(object, newdata) {+   # Return averaged prediciton for class of interest+   mean(predict(object, newx = newdata, s = iris_fit$fit$lambda[10], +        type = "link")[, 1L])+ }> > # PDP-based VI> features <- setdiff(names(iris1), "class")> vip::vi_firm(+   object = iris_fit$fit, +   feature_names = features, +   train = X, +   pred.fun = pfun+ )# A tibble: 4 x 2  Variable     Importance  <chr>             <dbl>1 Sepal.Length       0   2 Sepal.Width        0   3 Petal.Length       1.274 Petal.Width        0   > > # PDP> pd <- pdp::partial(iris_fit$fit, "Petal.Length", pred.fun = pfun, +                    train = X)> head(pd)  Petal.Length      yhat1     1.000000 1.06447562     1.140476 0.96322283     1.280952 0.86197004     1.421429 0.76071725     1.561905 0.65946446     1.702381 0.5582116

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注