我想重写nn.Linear的功能。问题在于输入大小为(N, *, in_feature),权重大小为(out_feature, in_feature)。如果我想使用Python得到结果为(N,*,out_feature),我应该如何编写代码?
input @ weight.T
遗憾的是,这是不正确的。
回答:
为了应用@,即__matmul__,大小需要匹配:输入x的形状为(N, *, in_feature),权重张量w的形状为(out_feature, in_feature)。
x = torch.rand(2, 4, 4, 10)w = torch.rand(5, 10)
对w进行转置会得到形状为(in_feature, out_feature)。在x和w.T之间应用__matmul__会减少到形状(N, *, out_feature):
>>> z = [email protected]>>> z.shapetorch.Size([2, 4, 4, 5])
或者等效地使用torch.matmul:
>>> z = torch.matmul(x, w.T)