使用rpart决策树进行数据预测

我正在使用R语言对一个名为’d’的数据框进行分类,该数据框包含如下结构的数据:

initial dataset

数据有576666行,”classLabel”列是一个包含三个级别的因子:ONE, TWO, THREE。

我正在使用rpart创建一个决策树:

fitTree = rpart(d$classLabel ~ d$tripduration + d$from_station_id + d$gender +  d$birthday)

我想对newdata中的”classLabel”进行预测:

newdata = data.frame( tripduration=c(345,244,543,311),                       from_station_id=c(60,28,100,56),                      gender=c("Male","Female","Male","Male"),                        birthday=c(1972,1955,1964,1967) ) p <- predict(fitTree, newdata)

我期望的结果是一个4行的矩阵,每行包含newdata中”classLabel”的三个可能值的概率。但我得到的结果p是一个如下所示的576666行的数据框:

enter image description here

在运行predict函数时,我还收到了以下警告:

Warning message:'newdata' had 4 rows but variables found have 576666 rows 

我哪里做错了?!


回答:

我认为问题在于:你应该在预测代码中添加”type=’class'”参数:

    predict(fitTree,newdata,type="class")

尝试以下代码。我在这个例子中使用了”iris”数据集。

    > data(iris)    > head(iris)    Sepal.Length Sepal.Width Petal.Length Petal.Width Species  1          5.1         3.5          1.4         0.2  setosa  2          4.9         3.0          1.4         0.2  setosa  3          4.7         3.2          1.3         0.2  setosa  4          4.6         3.1          1.5         0.2  setosa  5          5.0         3.6          1.4         0.2  setosa  6          5.4         3.9          1.7         0.4  setosa  # 模型拟合  > fitTree<-rpart(Species~Sepal.Length+Sepal.Width+Petal.Length+Petal.Width,iris)  # 预测 - 一行数据  > newdata<-data.frame(Sepal.Length=7,Sepal.Width=4,Petal.Length=6,Petal.Width=2)  > newdata  Sepal.Length Sepal.Width Petal.Length Petal.Width  1            7           4            6           2 # 执行预测  > predict(fitTree, newdata,type="class")     1   virginica   Levels: setosa versicolor virginica # 预测 - 多行数据 > newdata2<-data.frame(Sepal.Length=c(7,8,6,5), +                      Sepal.Width=c(4,3,2,4), +                      Petal.Length=c(6,3.4,5.6,6.3), +                      Petal.Width=c(2,3,4,2.3)) > newdata2  Sepal.Length Sepal.Width Petal.Length Petal.Width   1            7           4          6.0         2.0   2            8           3          3.4         3.0   3            6           2          5.6         4.0   4            5           4          6.3         2.3# 执行预测> predict(fitTree,newdata2,type="class")      1         2         3         4  virginica virginica virginica virginica  Levels: setosa versicolor virginica

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注