我想创建640个全连接层。
(在def init中)
self.fc0 = nn.Linear(120, M)self.fc1 = nn.Linear(120, M).....self.fc638 = nn.Linear(120, M)self.fc639 = nn.Linear(120, M)
(在def forward中)
x[:,:,0,:] = self.fc0(x[:,:,0,:])x[:,:,1,:] = self.fc0(x[:,:,1,:]).......x[:,:,639,:] = self.fc639(x[:,:,639,:])
如何以更简单的方式执行上述代码?
回答:
使用容器:
class MultipleFC(nn.Module): def __init__(self, num_layers, M): self.layers = nn.ModuleList([nn.Linear(120, M) for _ in range(num_layers)]) def forward(self, x): y = torch.empty_like(x) # up to third dim should be M for i, fc in enumerate(self.layers): y[..., i, :] = fc(x[..., i, :]) return y