在Keras中神经网络的维度错误

我正在尝试实现一个神经网络,其中我将一个全连接神经网络与一个卷积神经网络合并/连接。但是当我拟合模型时,我遇到了以下错误:

ValueError: 所有输入数组(x)应该具有相同数量的样本。得到的数组形状为:[(1, 100, 60, 4500), (100, 4500)]

我有两个不同的输入:

  1. 图像(尺寸:1,100,60,4500),其中1是通道,100是样本数,60*4500是我的图像的尺寸。这将输入到我的卷积神经网络

  2. 位置(尺寸:100,4500):其中100指的是样本数。

我的输出的维度是100,2。

我的神经网络的代码是:

###Convolution neural networkb1 = Sequential()b1.add(Conv2D(128*2, kernel_size=3,activation='relu',data_format='channels_first',              input_shape=(100,60,4500)))b1.add(Conv2D(128*2, kernel_size=3, activation='relu'))b1.add(Dropout(0.2))b1.add(Conv2D(128*2, kernel_size=4, activation='relu'))b1.add(Dropout(0.2))b1.add(Flatten())b1.summary()###Fully connected feed forward neural networkb2 = Sequential()b2.add(Dense(64, input_shape = (4500,), activation='relu'))b2.add(Dropout(0.1))b2.summary()model = Sequential()###Concatenating the two networksconcat = concatenate([b1.output, b2.output], axis=-1)x = Dense(256, activation='relu', kernel_initializer='normal')(concat)x = Dropout(0.25)(x)output = Dense(2, activation='softmax')(x)model = Model([b1.input, b2.input], [output])model.compile(loss=keras.losses.categorical_crossentropy,              optimizer=keras.optimizers.Adam(),              metrics=['accuracy'])history = model.fit([image, positions], Ytest, batch_size=10,              epochs=1,              verbose=1)

另外,我的’图像’数组是四维的原因是最初它只是(100,60,4500),但后来我遇到了以下错误:

ValueError: 检查输入时出错:期望conv2d_10_input具有4个维度,但得到的数组形状为(100, 60, 4500)

通过谷歌搜索,我发现它也需要通道数作为输入。在我添加了通道数后,这个错误消失了,但随后我遇到了开始时提到的另一个错误。

所以有人能告诉我如何解决这个错误(我一开始指定的那个)吗?将不胜感激。


回答:

您的样本数(批量大小)始终应该是第一个维度。因此,您的数据应该具有形状(100, 1, 60, 4500)用于图像,(100, 4500)用于位置。对于Conv2D层,channels_first参数意味着通道是第一个非批量维度。

您还需要将第一个Conv2D层的输入形状更改为(1, 60, 4500)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注