如何在两个不同的Keras层之间创建自定义(卷积)连接

我正在实现两个不同Keras层之间的自定义连接。神经网络的开始类似于以下内容:

model = tf.keras.Sequential()c1 = model.add(Conv2D(6, kernel_size=[5,5], strides=(stride,stride), padding="valid", input_shape=(32,32,1),                   activation = 'tanh'))s2 = model.add(AveragePooling2D(pool_size=2, strides=2, padding='valid'))

现在,s2的输出大小为14*14*6

在这里,我希望将我的自定义连接应用于卷积层c3,其输出大小为10*10*16(即,需要在大小为14*14*6的s2上应用16个滤波器,以获得10*10*16的输出)。为此,我需要使用kernal_size = 5*5filers=16stride = 1,和padding=valid

然而,并非s2的所有6个特征图都连接到c3的16个特征图。连接的解释如此处所示。

例如(如上所述链接的解释),要构建C3的第一个特征图,你需要用5×5的滤波器卷积s2的3个输入图(大小为14*14*6),这将给你3个10×10的图,这些图相加后得到你的第一个特征图,大小为10×10。

我曾在某处读到,我们需要使用函数式API来构建这个连接。

但是,我不确定如何进一步进行。有人能帮助实现这个吗?

我实现这个的初步方法如下:

from keras.models import Modelfrom keras.layers import Conv2D, Input, Concatenate, Lambda, AddinputTensor = Input(shape=(14, 14,6))stride =1group0_a = Lambda(lambda x: x[:,:,0])(inputTensor)group0_b = Lambda(lambda x: x[:,:,1])(inputTensor)group0_c = Lambda(lambda x: x[:,:,2])(inputTensor) # 提取s2的0,1,2特征图conv_group0_a = Conv2D(1, kernel_size=[5,5], strides=(stride,stride), padding="valid", activation = 'tanh')(group0_a)conv_group0_b = Conv2D(1, kernel_size=[5,5], strides=(stride,stride), padding="valid", activation = 'tanh')(group0_b)conv_group0_c = Conv2D(1, kernel_size=[5,5], strides=(stride,stride), padding="valid", activation = 'tanh')(group0_c)  # 对s2的0, 1, 2特征图分别应用卷积,使用不同的核added_0 = Add()([conv_group0_a, conv_group0_b, conv_group0_c]) # 将三个卷积结果相加,得到10*10*16中的一个# 对c3的16个神经元重复此过程,最后output_layer = Concatenate()([]) # 将它们连接起来Mymodel = Model(inputTensor,output_layer)

我想知道我的方法是否正确(我知道它不是,因为我遇到了很多错误)。所以,我需要帮助重建如上所述的自定义连接。任何帮助都将不胜感激。


回答:

上面的代码是正确的,我唯一做的更改是group0_a = Lambda(lambda x: x[:,:,0:1])(inputTensor),也就是说,我不是将x作为x[:,:,0]传递,而是作为x[:,:,0:1]传递

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注