我有一个张量 x
并且 x.shape=(batch_size,10)
,现在我想进行以下操作:
x[i][0] = x[i][0]*x[i][1]*...*x[i][9] for i in range(batch_size)
这是我的代码:
for i in range(batch_size): for k in range(1, 10): x[i][0] = x[i][0] * x[i][k]
但是当我在 forward()
中实现这个并调用 loss.backward()
时,反向传播的速度非常慢。为什么会慢,有没有更高效的实现方法?
回答:
速度慢是因为你使用了两个for循环。
你可以使用 .prod
,参见: https://pytorch.org/docs/stable/torch.html#torch.prod
在你的情况下,
x = torch.prod(x, dim=1)
或 x = x.prod(dim=1)
应该可以工作