我有一个训练数据集,我从中随机抽样来在R中训练模型。
sampleSize <- floor(0.1 * nrow(trainingDataFrame))train_index <- sample(seq_len(nrow(trainingDataFrame)), size = sampleSize)trainDF <- trainingDataFrame[train_index,]fit <- train(dependentVariable ~ ., data=trainDF, trControl = trainControl(method = "cv",number = 10),method="lm")
然后,我使用该模型在另一个独立的测试数据集上进行predict
预测值。然而,由于我在训练集上进行随机抽样,测试数据集中某些行的因子字段偶尔会出现训练集中未遇到过的值。这导致了"factor ... has new levels"
错误。
实际上,我在多次迭代中重复这个抽样-训练-预测过程,因此每个测试数据集记录最终很可能会有一些有效的预测。因此,对于我的用例来说,在任何特定迭代中,某些记录无法被predict
预测是可以接受的。我希望不必从模型训练中排除该字段。
相反,是否可以指示predict
函数对于这些无效行返回na
或其他默认值?
回答:
这里有一个可能的解决方案。
关键是定义你自己的predict
函数,该函数比较lm
对象中factor
变量的级别与newdata
中的级别。我们只对那些factor
级别匹配的观测值进行predict
预测,并对所有其他观测值返回NA
。
我将使用mtcars
数据来演示。
-
首先,我们创建一个样本数据,包含一个数值响应变量(
mpg
)和3个分类预测变量(cyl
,gear
,carb
)。library(tidyverse)df <- mtcars %>% select(mpg, cyl, gear, carb) %>% mutate_at(vars(-mpg), as.factor)
-
然后,我们在仅包含
cyl
和gear
某些(但不是全部)因子级别的训练数据集上训练模型。df.train <- df %>% filter(cyl %in% c(4, 6) & gear %in% c(3, 4))
-
我们拟合一个简单的线性模型。
fit <- lm(mpg ~ ., data = df.train)
-
现在,我们定义一个自定义函数,该函数将
newdata
分为两部分:(1) 具有匹配factor
级别的观测值,我们可以对其进行predict
预测响应;(2) 具有“新”级别的观测值,我们返回NA
作为响应。所有分类变量的
factor
级别存储在fit$xlevels
中,作为一个list
。我们使用purrr::imap
和purrr::reduce(..., intersect)
来确定newdata
中具有匹配factor
级别的观测值的行索引。my.predict <- function(fit, newdata) { lvls <- fit$xlevels idx <- reduce(imap(lvls, ~which(newdata[, .y] %in% .x)), intersect) res <- rep(NA, nrow(newdata)) res[idx] <- predict(fit, newdata = newdata[idx, ]) return(res)}
-
我们在完整的
df
数据集上确认结果:df$pred <- my.predict(fit, df)df# mpg cyl gear carb pred#1 21.0 6 4 4 19.75#2 21.0 6 4 4 19.75#3 22.8 4 4 1 29.10#4 21.4 6 3 1 19.75#5 18.7 8 3 2 NA#6 18.1 6 3 1 19.75#7 14.3 8 3 4 NA#8 24.4 4 4 2 24.75#9 22.8 4 4 2 24.75#10 19.2 6 4 4 19.75#11 17.8 6 4 4 19.75#12 16.4 8 3 3 NA#13 17.3 8 3 3 NA#14 15.2 8 3 3 NA#15 10.4 8 3 4 NA#16 10.4 8 3 4 NA#17 14.7 8 3 4 NA#18 32.4 4 4 1 29.10#19 30.4 4 4 2 24.75#20 33.9 4 4 1 29.10#21 21.5 4 3 1 21.50#22 15.5 8 3 2 NA#23 15.2 8 3 2 NA#24 13.3 8 3 4 NA#25 19.2 8 3 2 NA#26 27.3 4 4 1 29.10#27 26.0 4 5 2 NA#28 30.4 4 5 2 NA#29 15.8 8 5 4 NA#30 19.7 6 5 6 NA#31 15.0 8 5 8 NA#32 21.4 4 4 2 24.75