可分离卷积层中的layer.get_weights()函数的含义是什么?

我了解到我们可以使用layer.get_weights()函数来获取层的权重和偏置。这将返回一个长度为2的列表。层的权重存储在layer.get_weights()[0]中,而偏置存储在layer.get_weights()[1]中(如果在定义层时没有禁用偏置)。这对于普通的卷积层是成立的。

我最近在EfficientDet模型中使用了可分离卷积层作为其中一层。

layers.SeparableConv2D(num_channels, kernel_size=kernel_size, strides=strides, padding='same',                            use_bias=True, name=str(name)+"/conv")

当我尝试使用相同的layer.get_weights()函数时,它返回了一个长度为3的列表,而我期望它是2,与上面相同。对此,我对列表中的三个值有些困惑。任何帮助和建议都将不胜感激。


回答:

SeparableConv2D层计算的是深度可分离卷积,与普通卷积不同,它需要2个核(2个权重张量)。无需过多细节,它使用第一个核来计算深度卷积,应用此操作后,它使用第二个核来计算点卷积。这样做的主要目的是减少参数数量,从而减少计算量。

这是一个简单的例子。假设我们有一张28x28x3的输入图像(宽度、高度、通道数),我们应用普通的2D卷积(假设16个滤波器和5×5的核,没有步长/填充)。

如果我们进行计算,最终得到5x5x3x16(5×5的滤波器大小,3个输入通道和16个滤波器)= 1200个核参数 + 16个偏置参数(每个滤波器一个)= 1216。我们可以验证这一点

model = tf.keras.models.Sequential([    tf.keras.layers.Input(shape=(28, 28, 3)),    tf.keras.layers.Conv2D(16, (5, 5)),])model.summary()

给我们

Layer (type)                 Output Shape              Param #   =================================================================conv2d_4 (Conv2D)            (None, 24, 24, 16)        1216

如果我们提取核参数。

print(model.layers[0].get_weights()[0].shape)

这给我们

(5, 5, 3, 16)

现在,让我们考虑可分离的2D卷积,它有2个核,深度核由每个输入通道的单独5x5x1权重矩阵组成,在我们案例中 – 5x5x3(5x5x3x1 – 为了与4D keras张量保持一致)。这给我们75个参数。

点核是一个简单的1×1卷积(它在每个输入点上操作),用于增加结果的深度到指定的滤波器数量。在我们案例中 – 1x1x3x16,这给我们48个参数。

总的来说,我们有第一个核的75个参数和第二个核的48个参数,这给我们123个参数,再加上16个偏置参数。也就是139个参数。

在keras中,

model = tf.keras.models.Sequential([    tf.keras.layers.Input(shape=(28, 28, 3)),    tf.keras.layers.SeparableConv2D(16, (5, 5)),])model.summary()

给我们

Layer (type)                 Output Shape              Param #   =================================================================separable_conv2d_7 (Separabl (None, 24, 24, 16)        139   

如我们所见,这层的输出形状与普通卷积层完全相同,但现在我们有2个核,参数少得多。同样,我们可以提取这两个核的参数,

print(model.layers[0].get_weights()[0].shape)print(model.layers[0].get_weights()[1].shape)

这给我们

(5, 5, 3, 1)(1, 1, 3, 16)

如果你想了解更多关于可分离卷积如何工作的详细信息,可以阅读这篇文章

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注