更好的方法来连接ConvLSTM2D模型和表格数据模型

我已经构建了一个模型,它以时间序列的3张图像和5个数值信息作为输入,并生成时间序列的接下来三张图像。我通过以下步骤实现了这一点:

  1. 为处理图像构建一个ConvLSTM2D模型(与Keras文档中列出的示例非常相似,在此)。输入大小=(3x128x128x3)
  2. 为表格数据构建一个简单的模型,包含几个Dense层。输入大小=(1,5)
  3. 将这两个模型连接起来
  4. 使用一个Conv3D模型生成接下来的3张图像

LSTM模型的输出大小为393216(3x128x128x8)。现在我不得不将表格模型的输出设置为49,152,这样我就可以在下一层中得到442368(3x128x128x9)的大小输入。因此,这种对表格模型Dense层的无谓膨胀使得原本高效的LSTM模型表现得非常糟糕。

有没有更好的方法来连接这两个模型?我能否只让表格模型的Dense层的输出为10?

模型如下:

x_input = Input(shape=(None, 128, 128, 3))x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)x = BatchNormalization()(x)x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)x = BatchNormalization()(x)x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)x = BatchNormalization()(x)x = Flatten()(x)# x = MaxPooling3D()(x)x_tab_input = Input(shape=(5))x_tab = Dense(100, activation="relu")(x_tab_input)x_tab = Dense(49152, activation="relu")(x_tab)x_tab = Flatten()(x_tab)concat = Concatenate()([x, x_tab])output = Reshape((3,128,128,9))(concat)output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)model = Model([x_input, x_tab_input], output)model.compile(loss='mae', optimizer='rmsprop')

模型摘要:

Model: "functional_3"______________________________________________________________________________________________________________________________________________________Layer (type)                                     Output Shape                     Param #           Connected to                                      ======================================================================================================================================================input_4 (InputLayer)                             [(None, None, 128, 128, 3)]      0                                                                   ______________________________________________________________________________________________________________________________________________________conv_lst_m2d_9 (ConvLSTM2D)                      (None, None, 128, 128, 32)       40448             input_4[0][0]                                     ______________________________________________________________________________________________________________________________________________________batch_normalization_9 (BatchNormalization)       (None, None, 128, 128, 32)       128               conv_lst_m2d_9[0][0]                              ______________________________________________________________________________________________________________________________________________________conv_lst_m2d_10 (ConvLSTM2D)                     (None, None, 128, 128, 16)       27712             batch_normalization_9[0][0]                       ______________________________________________________________________________________________________________________________________________________batch_normalization_10 (BatchNormalization)      (None, None, 128, 128, 16)       64                conv_lst_m2d_10[0][0]                             ______________________________________________________________________________________________________________________________________________________input_5 (InputLayer)                             [(None, 5)]                      0                                                                   ______________________________________________________________________________________________________________________________________________________conv_lst_m2d_11 (ConvLSTM2D)                     (None, None, 128, 128, 8)        6944              batch_normalization_10[0][0]                      ______________________________________________________________________________________________________________________________________________________dense (Dense)                                    (None, 100)                      600               input_5[0][0]                                     ______________________________________________________________________________________________________________________________________________________batch_normalization_11 (BatchNormalization)      (None, None, 128, 128, 8)        32                conv_lst_m2d_11[0][0]                             ______________________________________________________________________________________________________________________________________________________dense_1 (Dense)                                  (None, 49152)                    4964352           dense[0][0]                                       ______________________________________________________________________________________________________________________________________________________flatten_3 (Flatten)                              (None, None)                     0                 batch_normalization_11[0][0]                      ______________________________________________________________________________________________________________________________________________________flatten_4 (Flatten)                              (None, 49152)                    0                 dense_1[0][0]                                     ______________________________________________________________________________________________________________________________________________________concatenate (Concatenate)                        (None, None)                     0                 flatten_3[0][0]                                                                                                                                       flatten_4[0][0]                                   ______________________________________________________________________________________________________________________________________________________reshape_2 (Reshape)                              (None, 3, 128, 128, 9)           0                 concatenate[0][0]                                 ______________________________________________________________________________________________________________________________________________________conv3d_2 (Conv3D)                                (None, 3, 128, 128, 3)           732               reshape_2[0][0]                                   ======================================================================================================================================================Total params: 5,041,012Trainable params: 5,040,900Non-trainable params: 112______________________________________________________________________________________________________________________________________________________

回答:

我同意你的观点,巨大的Dense层(拥有数百万参数)可能会阻碍模型的性能。与其用Dense层来“膨胀”表格数据,你可以选择以下两种方法之一。


选项1: 平铺x_tab张量,使其匹配你所需的形状。这可以通过以下步骤实现:

首先,无需展平ConvLSTM2D的编码张量:

x_input = Input(shape=(3, 128, 128, 3))x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)x = BatchNormalization()(x)x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)x = BatchNormalization()(x)x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)x = BatchNormalization()(x)  # Shape=(None, None, 128, 128, 8) # 注释掉:x = Flatten()(x)

其次,你可以用一个或多个Dense层处理你的表格数据。例如:

dim = 10x_tab_input = Input(shape=(5))x_tab = Dense(100, activation="relu")(x_tab_input)x_tab = Dense(dim, activation="relu")(x_tab)# x_tab = Flatten()(x_tab)  # 注意:展平2D张量不会改变张量

第三,我们将tensorflow操作tf.tile包装在Lambda层中,有效地创建x_tab张量的副本,使其匹配所需的形状:

def repeat_tabular(x_tab):    h = x_tab[:, None, None, None, :]  # Shape=(bs, 1, 1, 1, dim)    h = tf.tile(h, [1, 3, 128, 128, 1])  # Shape=(bs, 3, 128, 128, dim)    return hx_tab = Lambda(repeat_tabular)(x_tab)

最后,我们沿最后一个轴连接x和平铺的x_tab张量(你也可以考虑沿第一个轴连接,对应于通道维度)

concat = Concatenate(axis=-1)([x, x_tab])  # Shape=(3,128,128,8+dim)output = concatoutput = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)# ...

请注意,这个解决方案可能有点简单,因为模型没有将输入的图像序列编码成低维表示,限制了网络的感受野,可能会导致性能下降。


选项2: 类似于自动编码器和U-Net,将你的图像序列编码成低维表示可能是有利的,这样可以丢弃不需要的变化(例如噪声),同时保留有意义的信号(例如推断序列的接下来的3张图像所需的信号)。这可以通过以下方式实现:

首先,将输入的图像序列编码成低维的2维张量。例如,类似于以下内容:

x_input = Input(shape=(None, 128, 128, 3))x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)x = BatchNormalization()(x)x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)x = BatchNormalization()(x)x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2, return_sequences=False)(x)x = BatchNormalization()(x)x = Flatten()(x)x = Dense(64, activation='relu')(x)

请注意,最后一个ConvLSTM2D不返回序列。你可能想探索不同的编码器来达到这一点(例如,你也可以在这里使用池化层)。

其次,用Dense层处理你的表格数据。例如:

dim = 10x_tab_input = Input(shape=(5))x_tab = Dense(100, activation="relu")(x_tab_input)x_tab = Dense(dim, activation="relu")(x_tab)

第三,连接来自前两个流的数据:

concat = Concatenate(axis=-1)([x, x_tab])

第四,使用Dense + Reshape层将连接的向量投影成一系列低分辨率图像:

h = Dense(3 * 32 * 32 * 3)(concat)output = Reshape((3, 32, 32, 3))(h)

output的形状允许将图像上采样到(128, 128, 3)的形状,但它本身是任意的(例如,你可能也想在这里进行实验)。

最后,应用一个或多个Conv3DTranspose层以获得所需的输出(例如,3张(128, 128, 3)形状的图像)。

output = tf.keras.layers.Conv3DTranspose(filters=50, kernel_size=(3, 3, 3),                                         strides=(1, 2, 2), padding='same',                                         activation='relu')(output)output = tf.keras.layers.Conv3DTranspose(filters=3, kernel_size=(3, 3, 3),                                         strides=(1, 2, 2), padding='same',                                         activation='relu')(output)  # Shape=(None, 3, 128, 128, 3)

转置卷积层的原理在这里讨论。本质上,Conv3DTranspose层与普通卷积的方向相反——它允许将你的低分辨率图像上采样成高分辨率图像。

Related Posts

如何对SVC进行超参数调优?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

如何在初始训练后向模型添加训练数据?

我想在我的scikit-learn模型已经训练完成后再…

使用Google Cloud Function并行运行带有不同用户参数的相同训练作业

我正在寻找一种方法来并行运行带有不同用户参数的相同训练…

加载Keras模型,TypeError: ‘module’ object is not callable

我已经在StackOverflow上搜索并阅读了文档,…

在计算KNN填补方法中特定列中NaN值的”距离平均值”时

当我从头开始实现KNN填补方法来处理缺失数据时,我遇到…

使用巨大的S3 CSV文件或直接从预处理的关系型或NoSQL数据库获取数据的机器学习训练/测试工作

已关闭。此问题需要更多细节或更清晰的说明。目前不接受回…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注