ValueError: 检查输入时出错:期望time_distributed_46_input有5个维度,但得到的数组形状为(200, 200, 3)

我在尝试使用Timedistributed层时遇到了困难。我正在尝试创建一个非常简单的模型,该模型可以读取200 x 200像素的RGB图像上写的字符。

我不断收到以下错误,但我不知道如何解决它:

ValueError: Error when checking input: expected time_distributed_46_input to have 5 dimensions, but got array with shape (200, 200, 3)

这是我的Keras代码:

num_timesteps = len(chars) # 序列长度
img_width = 200
img_height = 200
img_channels = 3
def model():
    # 定义CNN模型
    cnn = Sequential()
    cnn.add(Conv2D(64, (3,3), activation='relu', padding='same', input_shape=(img_width,img_height,img_channels)))
    cnn.add(MaxPooling2D(pool_size=(3, 3)))
    cnn.add(Flatten())
    # 定义LSTM模型
    model = Sequential()
    model.add(TimeDistributed(cnn, input_shape=(num_timesteps, img_width,img_height,img_channels)))
    model.add(LSTM(num_timesteps))
    model.add(Dense(26))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

然后我这样拟合模型:

model().fit_generator(generator=images_generator(), steps_per_epoch=20, epochs=2)

我生成图像的方式如下:

def image_sample():
    rand_str = random_str()
    blank=Image.new("RGB", (200,200),(255,255,255))
    font = ImageFont.truetype("StatePlate.ttf", 100)
    draw = ImageDraw.Draw(blank)
    draw.text((30, 40),rand_str,(0,0,0), font=font)
    draw = ImageDraw.Draw(blank)
#     datagen = ImageDataGenerator(rotation_range=90)
#     datagen.fit(blank)
    return (np.asarray(blank), one_hot_char(rand_str))
def one_hot_char(char):
    zeros = np.zeros(len(chars))
    zeros[chars.index(char)] = 1
    return zeros
def images_generator():
    yield image_sample()

任何帮助都将不胜感激!谢谢。


回答:

目前,生成器返回的是单个图像。生成器生成的输入应该具有形状:[batch_size, num_timesteps, img_width, img_height, img_channels]

对于这个虚拟数据的一个快速修复是将np.asarray(blank)更改为np.asarray([[blank]])

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注