如何在Keras中拆分和合并模型?

我正在尝试构建一个由两个自编码器组成的堆叠自编码器模型。我已经有了两个自编码器,但无法将它们连接起来。

这是我目前的进展

### AUTOENCODER 1 ###X_input = Input(input_shape)x = Conv2D(64, (4,1), activation='relu', padding='same')(X_input)x = Conv2D(32, (3,2), activation='relu', padding='same')(x)x = MaxPooling2D(name='encoded')(x)encoded_shape = x.shape.as_list()x = Conv2D(32, (3,2), activation='relu', padding='same')(x)x = UpSampling2D(name='up1')(x)x = Conv2D(64, (4,1), activation='relu', padding='same')(x)x = Conv2D(1, (3,3), name='decoded', padding='same')(x)ae1 = Model(X_input, x)enc_layer_ae1 = ae1.get_layer('encoded').output

### AUTOENCODER 2 ###X_input1 = Input(encoded_shape[1:])x1 = Conv2D(24, (3,3), activation='relu', padding='same')(X_input1)x1 = Conv2D(16, (2,2), activation='relu', padding='same')(x1)x1 = MaxPooling2D((2,3), name='encoded')(x1)x1 = UpSampling2D((2,3), name='up')(x1)x1 = Conv2D(16, (2,2), activation='relu', padding='same')(x1)x1 = Conv2D(24, (3,3), activation='relu', padding='same')(x1)x1 = Conv2D(32, (1,1), padding='same')(x1)ae2 = Model(X_input1, x1)enc_layer_ae2 = ae2.get_layer('encoded').output

此时,我想通过堆叠创建另一个模型

  • ae1 从第0层到encoded
  • ae2 的相同层
  • 一些额外的Dense

所以最终我的模型应该看起来像ae1_input > ae1_conv2d > ae1_conv2d > ae1_encoded > ae2_input > ae2_conv > ae2_conv > ae2_encoded > dense > softmax

我尝试过这样做

ae2_split = Model(X_input1, enc_layer_ae2)full_output = ae2_split(enc_layer_ae1)full_output = Dense(150, activation='relu')(full_output)full_output = Dense(7, activation='softmax')(full_output)full_model = Model(enc_layer_ae1.input, full_output)

但我认为这不是正确的做法。你能建议我一个正确的方法吗?

谢谢。


回答:

首先,你应该更改enc_layer_ae2层的输入。由于Keras中的层是可调用的,你可以轻松地将一个层应用到另一个层上。

enc_layer_ae1 = ae1.get_layer('encoded')enc_layer_ae2 = ae2.get_layer('encoded')enc_layer_ae2 = enc_layer_ae2(enc_layer_ae1.output)full_output = Dense(150, activation='relu')(enc_layer_ae2)full_output = Dense(7, activation='softmax')(full_output)model = Model(enc_layer_ae1.input, full_output)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注