如何解决“层conv1d的输入0与层不兼容”的错误?

请问您能帮我理解我在尝试构建的模型中遇到的错误吗?

我有训练集、验证集和测试集。训练数据的形状如下:

input_shape = train.shape[1:] #(1500,)

我使用Keras编写了以下模型:

input = Input(shape=(input_shape))# Conv1D + 全局最大池化x = layers.Conv1D(filters=32, padding="valid", activation="relu", strides=1, kernel_size=4)(input)x = layers.Conv1D(filters=32, padding="valid", activation="relu", strides=1, kernel_size=4)(x)x = layers.GlobalMaxPooling1D()(x)x = layers.Dense(128, activation="relu")(x)x = layers.Dropout(0.5)(x)predictions = layers.Dense(1,kernel_initializer='normal', name="predictions")(x)model = tf.keras.Model(input, predictions)model.compile(loss="mean squared error", optimizer="adam", metrics=[concordance_index])

我得到了以下错误:

ValueError                                Traceback (most recent call last)<ipython-input-60-59c3578104d3> in <module>()      6       7 # Conv1D + 全局最大池化----> 8 x = layers.Conv1D(filters=32, padding="valid", activation="relu", strides=1, kernel_size=4)(protein_input)      9 x = layers.Conv1D(filters=32, padding="valid", activation="relu", strides=1, kernel_size=4)(x)     10 x = layers.GlobalMaxPooling1D()(x)5 frames/usr/local/lib/python3.7/dist-packages/keras/engine/input_spec.py in assert_input_compatibility(input_spec, inputs, layer_name)    230                          ', found ndim=' + str(ndim) +    231                          '. Full shape received: ' +--> 232                          str(tuple(shape)))    233     # Check dtype.    234     if spec.dtype is not None:ValueError: 层conv1d_49的输入0与层不兼容:预期的最小维度为3,发现的维度为2。接收到的完整形状为:(None, 1500)

我的输入层是否不正确?还是因为Conv1d和最大池化层的顺序问题?


回答:

因为Conv1D层期望输入的形状为batch_shape + (steps, input_dim),所以您需要添加一个新的维度。因此:

X = tf.expand_dims(X,axis=2)print(X.shape)     # X.shape=(Samples, 1500, 1) 

然后,您的X形状变为(Samples,1500,1)

现在,让我们指定输入形状为:

input = Input(shape=(X.shape[1:])  

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注