使用rpart决策树进行数据预测

我正在使用R语言对一个名为’d’的数据框进行分类,该数据框包含如下结构的数据:

initial dataset

数据有576666行,”classLabel”列是一个包含三个级别的因子:ONE, TWO, THREE。

我正在使用rpart创建一个决策树:

fitTree = rpart(d$classLabel ~ d$tripduration + d$from_station_id + d$gender +  d$birthday)

我想对newdata中的”classLabel”进行预测:

newdata = data.frame( tripduration=c(345,244,543,311),                       from_station_id=c(60,28,100,56),                      gender=c("Male","Female","Male","Male"),                        birthday=c(1972,1955,1964,1967) ) p <- predict(fitTree, newdata)

我期望的结果是一个4行的矩阵,每行包含newdata中”classLabel”的三个可能值的概率。但我得到的结果p是一个如下所示的576666行的数据框:

enter image description here

在运行predict函数时,我还收到了以下警告:

Warning message:'newdata' had 4 rows but variables found have 576666 rows 

我哪里做错了?!


回答:

我认为问题在于:你应该在预测代码中添加”type=’class'”参数:

    predict(fitTree,newdata,type="class")

尝试以下代码。我在这个例子中使用了”iris”数据集。

    > data(iris)    > head(iris)    Sepal.Length Sepal.Width Petal.Length Petal.Width Species  1          5.1         3.5          1.4         0.2  setosa  2          4.9         3.0          1.4         0.2  setosa  3          4.7         3.2          1.3         0.2  setosa  4          4.6         3.1          1.5         0.2  setosa  5          5.0         3.6          1.4         0.2  setosa  6          5.4         3.9          1.7         0.4  setosa  # 模型拟合  > fitTree<-rpart(Species~Sepal.Length+Sepal.Width+Petal.Length+Petal.Width,iris)  # 预测 - 一行数据  > newdata<-data.frame(Sepal.Length=7,Sepal.Width=4,Petal.Length=6,Petal.Width=2)  > newdata  Sepal.Length Sepal.Width Petal.Length Petal.Width  1            7           4            6           2 # 执行预测  > predict(fitTree, newdata,type="class")     1   virginica   Levels: setosa versicolor virginica # 预测 - 多行数据 > newdata2<-data.frame(Sepal.Length=c(7,8,6,5), +                      Sepal.Width=c(4,3,2,4), +                      Petal.Length=c(6,3.4,5.6,6.3), +                      Petal.Width=c(2,3,4,2.3)) > newdata2  Sepal.Length Sepal.Width Petal.Length Petal.Width   1            7           4          6.0         2.0   2            8           3          3.4         3.0   3            6           2          5.6         4.0   4            5           4          6.3         2.3# 执行预测> predict(fitTree,newdata2,type="class")      1         2         3         4  virginica virginica virginica virginica  Levels: setosa versicolor virginica

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注