理解torch.nn.Flatten

我理解Flatten会移除除了一个维度之外的所有维度。例如,我理解flatten()

> t = torch.ones(4, 3)> ttensor([[1., 1., 1.],    [1., 1., 1.],    [1., 1., 1.],    [1., 1., 1.]])> flatten(t)tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

然而,我不太理解Flatten,特别是文档中的这个片段的含义:

>>> input = torch.randn(32, 1, 5, 5)>>> m = nn.Sequential(>>>     nn.Conv2d(1, 32, 5, 1, 1),>>>     nn.Flatten()>>> )>>> output = m(input)>>> output.size()torch.Size([32, 288])

我觉得输出应该大小为[160],因为32*5=160

问题1. 那么为什么输出大小是[32,288]

问题2. 我也不理解文档中给出的shape信息的含义:

enter image description here

问题3. 还有参数的含义:

enter image description here


回答:

这是默认行为的区别。torch.flatten默认情况下会展平所有维度,而torch.nn.Flatten默认情况下会从第二个维度(索引1)开始展平所有维度。

您可以在start_dimend_dim参数的默认值中看到这种行为。参数start_dim表示要展平的第一个维度(从零开始索引),参数end_dim表示要展平的最后一个维度。因此,当start_dim=1时,这是torch.nn.Flatten的默认值,第一个维度(索引0)不会被展平,但当start_dim=0时,这是torch.flatten的默认值,第一个维度会被包括在内。

这种差异背后的原因可能是torch.nn.Flatten旨在与torch.nn.Sequential一起使用,通常会对一批输入执行一系列操作,其中每个输入被独立处理。例如,如果您有一批图像并调用torch.nn.Flatten,典型的用例是分别展平每个图像,而不是展平整个批次。

如果您确实想使用torch.nn.Flatten展平所有维度,您可以简单地创建对象为torch.nn.Flatten(start_dim=0)

最后,文档中的形状信息只是涵盖了张量的形状将如何受到影响,阐明了第一个(索引0)维度保持不变。因此,如果您有一个形状为(N, *dims)的输入张量,其中*dims是一系列任意维度,输出张量的形状将为(N, *dims的乘积),因为除了批处理维度之外的所有维度都被展平。例如,形状为(3,10,10)的输入将具有形状为(3, 10 x 10) = (3, 100)的输出。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注