用于视频分类的convLSTM中的通道数

我创建了一个用于对灰度视频进行分类的convLSTM,这意味着它们只有一个通道。即使我将通道数定义为1,我仍然会得到以下错误:

ValueError: Error when checking input: expected conv_lst_m2d_1_inputto have 5 dimensions, but got array with shape (128, 176, 256, 256)

128是训练数据集的大小,176*256是每帧的分辨率,256是每个视频中的帧数。

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.20, shuffle=True, random_state=0) model = Sequential()model.add(ConvLSTM2D(filters = 64, kernel_size = (3, 3), return_sequences = False, data_format = "channels_last", input_shape = (seq_len, img_height, img_width, 1)))model.add(Dropout(0.2))model.add(Flatten())model.add(Dense(256, activation="relu"))model.add(Dropout(0.3))model.add(Dense(6, activation = "softmax")) model.summary() opt = keras.optimizers.SGD(lr=0.001)model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=["accuracy"]) earlystop = EarlyStopping(patience=7)callbacks = [earlystop]history = model.fit(x = X_train, y = y_train, epochs=40, batch_size = 8 , shuffle=True, validation_split=0.2, callbacks=callbacks)

回答:

你只需要扩展数据的最后一个维度即可

batch_dim, seq_len, img_height, img_width = 3, 17, 25, 25X = np.random.uniform(0,1, (batch_dim, seq_len, img_height, img_width))y = np.random.randint(0,6, batch_dim)print(X.shape)# expand input dimensionX = X[...,np.newaxis]print(X.shape)model = Sequential()model.add(ConvLSTM2D(filters = 64, kernel_size = (3, 3), return_sequences = False,                      data_format = "channels_last",                      input_shape = (seq_len, img_height, img_width, 1)))model.add(Dropout(0.2))model.add(Flatten())model.add(Dense(256, activation="relu"))model.add(Dropout(0.3))model.add(Dense(6, activation = "softmax"))model.summary()model.predict(X).shape

Related Posts

在PyTorch中,如何训练具有两个或多个输出的模型?

output_1, output_2 = model(…

如何将数据集中字符串列转换为整数?

数据集中部分数据以字符串格式存在,我需要将它们全部映射…

VIF无截距:vifs可能不合理

我正在尝试测试我的多项式逻辑回归模型的假设是否成立或失…

如何基于训练模型预测未来的K线

假设我们使用Keras训练了一个模型,其准确率超过90…

Stylegan2-ada tfrecords – ValueError: 轴与数组不匹配,图像运行一次有效,下一次可能无效

我在使用Google Colab训练一个GAN,使用从…

如何在A100 GPU上使用Pytorch(+ cuda)?

我在尝试使用A100 GPU运行我的现有代码时遇到了以…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注