理解卷积层形状

我一直在阅读关于卷积神经网络的内容,并且自己也编写了一些模型。当我看到其他模型的可视化图表时,每一层都比前一层更小更深。层有三个维度,比如256x256x32。这个第三个数字是什么?我认为前两个数字是节点的数量,但我不知道深度是什么。


回答:

TLDR; 256x256x32指的是层的输出形状,而不是层本身。


有很多文章和帖子解释了卷积层是如何工作的。我会尽量不涉及太多细节,仅关注形状来回答你的问题。

假设你正在使用2D卷积层,你的输入和输出都将是三维的。也就是说,不考虑批次,这将对应第四个轴…因此,卷积层的输入形状将是(c, h, w)(或(h, w, c),这取决于框架),其中c是通道数,h是输入的高度,w是宽度。你可以将其视为一个具有c通道的hxw图像。最直观的例子是卷积神经网络的第一层卷积层的输入:很可能是一个大小为hxw的图像,具有c个通道,例如c=1代表灰度图像或c=3代表RGB

重要的是,对于该输入的所有像素,每个通道上的值为该像素提供了额外的信息。拥有三个通道将使每个像素(‘像素’指的是2D输入空间中的位置)的内容比只有一个通道更丰富。因为每个像素将使用三个值(三个通道)而不是一个值(一个通道)进行编码。这种关于通道代表什么的直觉可以推广到更多的通道。正如我们所说,输入可以有c个通道。

现在回到卷积层,这里有一个很好的可视化。想象有一个5×5的单通道输入。以及一个由单个3×3滤波器(即kernel_size=3)组成的卷积层

输入 滤波器 卷积 输出
形状 (1, 5, 5) (3, 3) (3,3)
表示

现在请记住,输出的维度将取决于卷积层的步长填充。这里输出的形状与滤波器的形状相同,但不一定必须如此!假设输入形状为(1, 5, 5),使用相同的卷积设置,你最终将得到一个形状为(4, 4)的输出(这与滤波器形状(3, 3)不同)。

还有需要注意的一点是,如果输入有多个通道:形状(c, h, w),滤波器也必须有相同数量的通道。输入的每个通道将与滤波器的每个通道进行卷积,结果将平均成一个单一的2D特征图。因此,你将得到一个中间输出形状为(c, 3, 3),在通道上平均后,将剩下(1, 3, 3)=(3, 3)。结果是,考虑使用单个滤波器的卷积,无论输入有多少个通道,输出总是只有一个通道。

从这里你可以做的是在同一层上组装多个滤波器。这意味着你定义你的层有k3×3滤波器。所以一层包含k个滤波器。对于输出的计算,思路很简单:一个滤波器产生一个(3, 3)的特征图,所以k个滤波器将产生k(3, 3)的特征图。这些图将堆叠成通道维度。最终,你得到的输出形状是…(k, 3, 3)

k_hk_w分别为卷积核的高度和宽度。以及h'w'为一个输出的特征图的高度和宽度:

输入 输出
形状 (c, h, w) (k, c, k_h, k_w) (k, h', w')
描述 c通道hxw特征图 k个形状为(c, k_h, k_w)的滤波器 k通道h'xw'特征图

回到你的问题:

层有三个维度,比如256x256x32。这个第三个数字是什么?我认为前两个数字是节点的数量,但我不知道深度是什么。

卷积层有四个维度,但其中一个由你的输入通道数决定。你可以选择卷积核的大小和滤波器的数量。这个数字将决定输出的通道数。

256×256看起来非常高,你很可能指的是特征图的输出形状。另一方面,32将是输出的通道数,这…正如我试图解释的,是该层中滤波器的数量。通常来说,卷积网络的可视化图表中表示的维度对应于中间输出的形状,而不是层的形状。

例如,以VGG神经网络为例:

!用于大规模图像识别的非常深的卷积网络

VGG的输入形状为(3, 224, 224),知道第一层卷积的结果形状为(64, 224, 224),你可以确定该层总共有64个滤波器。

事实证明,VGG中的核大小为3×3。所以,这里有一个问题给你:知道每个滤波器只有一个偏置参数,VGG的第一层卷积层总共有多少参数?

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注