如何将独热编码数据传递给神经网络模型进行预测?

我是数据科学的新手,想在R中构建一个神经网络模型。我已经阅读了关于在训练前对分类数据进行独热编码的内容。我尝试实现了这一点,然而,在尝试训练模型时,我收到了以下错误:

Error in model.frame.default(formula = nndf$class ~ ., data = train) :   invalid type (list) for variable 'nndf$class'

我阅读了nnet的文档,文档中解释说公式应该按以下方式传递:

class ~ x1 + x2

但我仍然不确定如何正确地传递数据。

这是代码:

nndf$al <- one_hot(as.data.table(nndf$al))nndf$su <- one_hot(as.data.table(nndf$su))nndf$rbc <- one_hot(as.data.table(nndf$rbc))nndf$pc <- one_hot(as.data.table(nndf$pc))nndf$pcc <- one_hot(as.data.table(nndf$pcc))nndf$ba <- one_hot(as.data.table(nndf$ba))nndf$htn <- one_hot(as.data.table(nndf$htn))nndf$dm <- one_hot(as.data.table(nndf$dm))nndf$cad <- one_hot(as.data.table(nndf$cad))nndf$appet <- one_hot(as.data.table(nndf$appet))nndf$pe <- one_hot(as.data.table(nndf$pe))nndf$ane <- one_hot(as.data.table(nndf$ane))nndf$class <- one_hot(as.data.table(nndf$class))class(nndf$class)# view the dataframe to ensure one hot encoding is correctsummary(nndf)# randomly sample rows for tt splittrain_idx <- sample(1:nrow(nndf), 0.8 * nrow(nndf))test_idx <- setdiff(1:nrow(nndf), train_idx)# prepare training set and corresponding labelstrain <- nndf[train_idx,]# prepare testing set and corresponding labelsX_test <- nndf[test_idx,]y_test <- nndf[test_idx, "class"]# create model with a single hidden layer containing 500 neuronsmodel <- nnet(nndf$class~., train, maxit=150, size=10)# predictionX_pred <- predict(train, type="raw")

回答:

假设

数据集中(nndf)的所有变量都是分类变量。

步骤

  1. 将除响应变量(即class)之外的所有变量转换为独热编码(即0,1格式)

one_hot方法

  one_hot_df <- one_hot(nndf[, -13]) # 13是`class`变量的索引。

model.matrix方法

  model_mat_df <- model.matrix( ~ . - 1, nndf[, -13])
  1. class转换为因子,并将其添加到上述任一数据框中。

    class <- as.factor(nndf$class)
    final_df <- cbind(model_mat_df, class)

  2. final_df分割成训练集和测试集,并在模型中使用这些数据。

    nnet(class~., train, maxit=150, size=10)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注