使用scikit-learn中的tree.DecisionTreeClassifier处理维度

我正在尝试使用scikit-learn创建一个决策树,使用三维的训练数据和二维的目标数据。举个简单的例子,想象一下RGB图像。假设我的目标数据是1和0,其中1表示有人脸存在,0表示没有人脸存在。例如:

red         green        blue        face presence  1000        0001         0011        0000    0110        0110         0001        0110    0110        0110         0000        0110     

RGB数据的数组将代表训练数据,而二维数组将代表我的目标类别(有脸,无脸)。

在Python中,这些数组可能看起来像这样:

rgb = np.array([[[1,0,0,0],[0,1,1,0],[0,1,1,0]],               [[0,0,0,1],[0,1,1,0],[0,1,1,0]],               [[0,0,1,1],[0,0,0,1],[0,0,0,0]]])face = np.array([[0,0,0,0],[0,1,1,0],[0,1,1,0]])

遗憾的是,这不起作用

import numpy as npfrom sklearn import treedt_clf = tree.DecisionTreeClassifier()dt_clf = dt_clf.fit(rgb, face)

这会抛出以下错误:

发现维度为3的数组。期望<= 2

我尝试了多种方式重塑和扁平化数据,但得到了另一个错误:

标签数量=xxx 与样本数量不匹配

有谁知道我如何使用tree.DecisionTreeClassifier来实现这个目标?谢谢。


回答:

我想我已经解决了这个问题。方法不是很优雅。也许有人可以提供一些帮助来清理代码。基本上,我需要将RGB数据组织成一个包含12个3元素数组的数组,或者形状为(12,3)。例如…

np.hsplit(np.dstack(rgb).flatten(), len(face.flatten()))

我还扁平化了脸部数据,所以我的最终拟合调用变为…

dt_clf = dt_clf.fit(np.hsplit(np.dstack(rgb).flatten(), len(face.flatten())),                     face.flatten())

现在我可以测试一个新的数据集,看看它是否有效。目标图像在红色和绿色像素都显示时表示有脸存在,所以一个好的测试可能是…

red         green        blue 1100        1100         0011  1100        1100         0001  0000        0000         0000

或者…

predict = np.array([[[1,1,0,0],[1,1,0,0],[0,0,0,0]],                    [[1,1,0,0],[1,1,0,0],[0,0,0,0]],                    [[0,0,1,1],[0,0,0,1],[0,0,0,0]]])

所以…

predicted = dt_clf.predict(np.hsplit(np.dstack(predict).flatten(),                           len(face.flatten())))

然后将其恢复到正确的维度…

predicted = np.array(np.hsplit(predicted, face.shape[0]))

这将得到

array([[1, 1, 0, 0],       [1, 1, 0, 0],       [0, 0, 0, 0]])

太棒了!现在看看这是否能在更大的数据集上工作。请随时提供建议,使代码更简洁。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注