我理解Flatten会移除除了一个维度之外的所有维度。例如,我理解flatten():
> t = torch.ones(4, 3)> ttensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.]])> flatten(t)tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
然而,我不太理解Flatten
,特别是文档中的这个片段的含义:
>>> input = torch.randn(32, 1, 5, 5)>>> m = nn.Sequential(>>> nn.Conv2d(1, 32, 5, 1, 1),>>> nn.Flatten()>>> )>>> output = m(input)>>> output.size()torch.Size([32, 288])
我觉得输出应该大小为[160]
,因为32*5=160
。
问题1. 那么为什么输出大小是[32,288]
?
问题2. 我也不理解文档中给出的shape
信息的含义:
问题3. 还有参数的含义:
回答:
这是默认行为的区别。torch.flatten
默认情况下会展平所有维度,而torch.nn.Flatten
默认情况下会从第二个维度(索引1)开始展平所有维度。
您可以在start_dim
和end_dim
参数的默认值中看到这种行为。参数start_dim
表示要展平的第一个维度(从零开始索引),参数end_dim
表示要展平的最后一个维度。因此,当start_dim=1
时,这是torch.nn.Flatten
的默认值,第一个维度(索引0)不会被展平,但当start_dim=0
时,这是torch.flatten
的默认值,第一个维度会被包括在内。
这种差异背后的原因可能是torch.nn.Flatten
旨在与torch.nn.Sequential
一起使用,通常会对一批输入执行一系列操作,其中每个输入被独立处理。例如,如果您有一批图像并调用torch.nn.Flatten
,典型的用例是分别展平每个图像,而不是展平整个批次。
如果您确实想使用torch.nn.Flatten
展平所有维度,您可以简单地创建对象为torch.nn.Flatten(start_dim=0)
。
最后,文档中的形状信息只是涵盖了张量的形状将如何受到影响,阐明了第一个(索引0)维度保持不变。因此,如果您有一个形状为(N, *dims)
的输入张量,其中*dims
是一系列任意维度,输出张量的形状将为(N, *dims的乘积)
,因为除了批处理维度之外的所有维度都被展平。例如,形状为(3,10,10)
的输入将具有形状为(3, 10 x 10) = (3, 100)
的输出。