Keras Python 多图像输入形状错误

我正在尝试自学构建一个可以接受多个图像作为输入的卷积神经网络。由于我创建的用于测试的数据集很大,而且我希望长期解决涉及非常大数据集的问题,我使用了一个生成器来将图像读取到数组中,并将这些数组传递给Keras模型的fit_generator函数。

当我单独运行我的生成器时,它运行正常,并且产生了适当形状的输出。它输出的元组包含两个条目,第一个条目的形状为(4, 100, 100, 1),第二个条目的形状为(4, )

阅读关于多输入Keras CNN的资料后,我认为这是一个用于识别四个输入中的哪一个包含图像的四输入CNN的正确生成器格式。

然而,当我运行代码时,我得到了以下错误:

"ValueError: Error when checking input: expected input_121 to have 4 dimensions, but got array with shape (100, 100, 1)"

我已经寻找解决方案有一段时间了,我怀疑问题在于如何将我的(100, 100, 1)形状的数组发送到输入作为(None, 100, 100, 1)形状的数组。

但当我尝试修改生成器的输出时,我得到了关于维度为5的错误,这是一个合理的错误,因为生成器的输出应该具有X, y = [X1, X2, X3, X4], [a, b, c, d]的形式,其中Xn的形状为(100, 100, 1),a/b/c/d是数字。

这是我的代码:

https://gist.github.com/anonymous/d283494aee982fbc30f3b52f2a6f422c

提前感谢!


回答:

你在生成器中创建的数组列表的维度是错误的。

如果你想要正确的形状,重塑单个图像以具有4个维度:(n_samples, x_size, y_size, n_bands),你的模型将能正常工作。在你的情况下,你应该将图像重塑为(1, 100, 100, 1)

最后,使用np.vstack将它们堆叠起来。生成器将输出一个形状为(4, 100, 100, 1)的数组。

检查这个修改后的代码是否有效

def input_generator(folder, directories):    Streams = []    for i in range(len(directories)):        Streams.append(os.listdir(folder + "/" + directories[i]))        for j in range(len(Streams[i])):            Streams[i][j] = "Stream" + str(i + 1) + "/" + Streams[i][j]           Streams[i].sort()    length = len(Streams[0])    index = 0    while True:        X = []        y = np.zeros(4)        for Stream in Streams:            image = load_img(folder + '/' + Stream[index], grayscale = True)            array = img_to_array(image).reshape((1,100,100,1))            X.append(array)        y[int(Stream[index][15]) - 1] = 1        index += 1        index = index % length        yield np.vstack(X), y

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注