Keras Dense层错误,期望4维度但得到形状为(1024,2)的数组 [duplicate]

我正在尝试使用启用了GPU的Tensorflow后端的Keras训练一个包含3层Dense神经网络的模型。

我的数据集有400万张20×40像素的图像,我将它们放置在以它们所属类别命名的目录中。

由于数据量巨大,我无法将所有数据一次性加载到RAM中并输入到模型中,因此我想使用Keras的ImageDataGenerator,特别是flow_from_directory()函数来解决这个问题。这会生成一个(x, y)的元组,其中x是图像的numpy数组,y是图像的标签。

我期望模型能够访问numpy数组作为输入,因此我设置了输入形状为:(None,20,40,3),其中None是批量大小,20和40是图像的尺寸,3是图像的通道数。然而,这并不起作用,因为当我尝试训练模型时,我不断收到错误:ValueError: 检查目标时出错:期望dense_3具有4维度,但得到形状为(1024, 2)的数组

我知道原因是它从flow_from_directory获取了元组,我想我可以更改输入形状以匹配,但是我担心这会使我的模型失去作用,因为我将使用图像进行预测而不是预分类元组。所以我的问题是,如何让flow_from_directory将图像输入我的模型,并仅使用元组来验证其训练?我在这里是否有什么误解?

作为参考,这里是我的代码:

from keras.models import Modelfrom keras.layers import *from keras.preprocessing.image import ImageDataGeneratorfrom keras.callbacks import TensorBoard# 准备图像数据生成器。train_datagen = ImageDataGenerator()test_datagen = ImageDataGenerator()train_generator = train_datagen.flow_from_directory(    '/path/to/train_data/',    target_size=(20, 40),    batch_size=1024,)test_generator = test_datagen.flow_from_directory(    '/path/to/test_data/',    target_size=(20, 40),    batch_size=1024,)# 定义输入张量。input_t = Input(shape=(20,40,3))# 现在创建层并将输入张量传递给它。hidden_1 = Dense(units=32, activation='relu')(input_t)hidden_2 = Dense(units=16)(hidden_1)prediction = Dense(units=1)(hidden_2)# 现在将所有部分组合在一起并创建模型。model = Model(inputs=input_t, outputs=prediction)model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])# 准备Tensorboard回调并开始训练。tensorboard = TensorBoard(log_dir='./graph', histogram_freq=0, write_graph=True, write_images=True)print(test_generator)model.fit_generator(    train_generator,    steps_per_epoch=2000,    epochs=100,    validation_data=test_generator,    validation_steps=800,    callbacks=[tensorboard])# 保存训练后的模型。model.save('trained_model.h5')

回答:

您的Dense层输入形状设置错误。

Dense层期望输入形状为(None,length)。

您需要将输入重塑为向量:

imageBatch=imageBatch.reshape((imageBatch.shape[0],20*40*3))

或者使用卷积层,这类层期望输入形状为(None,nRows,nCols,nChannels),如在tensorflow中那样。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注