在Pytorch中,对于以下代码:
torch.arange(0, 3).view(-1, *[1]*3)
结果是:
tensor([[[[0]]], [[[1]]], [[[2]]]]) torch.Size([3, 1, 1, 1])
其中 [1] * 3 = [1, 1, 1],但我不理解 [1] * 3 前面的 * 号。这个 * 号的含义是什么?谢谢。
回答:
虽然评论中提供的链接描述了解决方案的一部分,但可能缺少完整的解释,因此,让我们来解构这个view
方法:
.view(-1,...)
表示“所有元素”,在你的例子中是3,因为你有[0, 1, 2]
,长度为3。
接下来:
[1] * 3
这是Python的一个技巧,用于创建一个新列表,其中单个元素重复多次。
这相当于:
[1, 1, 1]
现在,使用星号进行解包,将值作为函数的参数“解包”,在这种情况下:
.view(-1, [1, 1, 1])
变为:
.view(-1, 1, 1, 1)
整个操作(根据第一步)是:
.view(3, 1, 1, 1)
顺便说一下,在大多数情况下请不要这样做,因为如上所示,这很难理解。