我想重写nn.Linear的功能。问题在于输入大小为(N, *, in_feature),权重大小为(out_feature, in_feature)。如果我想使用Python得到结果为(N,*,out_feature),我应该如何编写代码?
input @ weight.T
遗憾的是,这是不正确的。
回答:
为了应用@
,即__matmul__
,大小需要匹配:输入x
的形状为(N, *, in_feature)
,权重张量w
的形状为(out_feature, in_feature)
。
x = torch.rand(2, 4, 4, 10)w = torch.rand(5, 10)
对w
进行转置会得到形状为(in_feature, out_feature)
。在x
和w.T
之间应用__matmul__
会减少到形状(N, *, out_feature)
:
>>> z = [email protected]>>> z.shapetorch.Size([2, 4, 4, 5])
或者等效地使用torch.matmul
:
>>> z = torch.matmul(x, w.T)