Keras.layers.concatenate 生成错误’

我正在尝试训练一个具有两个输入分支的CNN。这两个分支(b1, b2)将合并到一个具有256个神经元的全连接层,并设置0.25的 dropout 率。以下是我目前的代码:

batch_size, epochs = 32, 3ksize = 2l2_lambda = 0.0001### 我的第一个模型(b1)b1 = Sequential()b1.add(Conv1D(128*2, kernel_size=ksize,             activation='relu',             input_shape=( xtest.shape[1], xtest.shape[2]),             kernel_regularizer=keras.regularizers.l2(l2_lambda)))b1.add(Conv1D(128*2, kernel_size=ksize, activation='relu',kernel_regularizer=keras.regularizers.l2(l2_lambda)))b1.add(MaxPooling1D(pool_size=ksize))b1.add(Dropout(0.2))b1.add(Conv1D(128*2, kernel_size=ksize, activation='relu',kernel_regularizer=keras.regularizers.l2(l2_lambda)))b1.add(MaxPooling1D(pool_size=ksize))b1.add(Dropout(0.2))b1.add(Flatten())###我的第二个模型 (b2)b2 = Sequential()b2.add(Dense(64, input_shape = (5000,), activation='relu',kernel_regularizer=keras.regularizers.l2(l2_lambda)))b2.add(Dropout(0.1))##合并两个模型model = Sequential()model.add(concatenate([b1, b2],axis = -1))model.add(Dense(256, activation='relu', kernel_initializer='normal',kernel_regularizer=keras.regularizers.l2(l2_lambda)))model.add(Dropout(0.25))model.add(Dense(num_classes, activation='softmax'))

但当我尝试合并时,出现了以下错误:

enter image description here

我首先尝试使用以下命令:

  model.add(Merge([b1, b2], mode = 'concat'))

但我得到了错误‘ImportError: cannot import name ‘Merge”。我使用的是 Keras 2.2.2 和 Python 3.6。


回答:

你需要使用函数式API来实现你想要的功能。你可以使用Concatenate层或其等效的函数式APIconcatenate

concat = Concatenate(axis=-1)([b1.output, b2.output])# 或者你可以使用函数式API如下:#concat = concatenate([b1.output, b2.output], axis=-1)x = Dense(256, activation='relu', kernel_initializer='normal',          kernel_regularizer=keras.regularizers.l2(l2_lambda))(concat)x = Dropout(0.25)(x)output = Dense(num_classes, activation='softmax')(x)model = Model([b1.input, b2.input], [output])

请注意,我只将你模型的最后一部分转换成了函数式形式。你可以对另外两个模型b1b2做同样的事情(实际上,看起来你试图定义的架构是一个由两个合并在一起的分支组成的单一模型)。最后,使用model.summary()来查看和重新检查模型的架构。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注