如何在nn.Sequential
中展平输入
Model = nn.Sequential(x.view(x.shape[0],-1), nn.Linear(784,256), nn.ReLU(), nn.Linear(256,128), nn.ReLU(), nn.Linear(128,64), nn.ReLU(), nn.Linear(64,10), nn.LogSoftmax(dim=1))
回答:
你可以创建一个新的模块/类如下,并在Sequential中像使用其他模块一样使用它(调用Flatten()
)。
class Flatten(torch.nn.Module): def forward(self, x): batch_size = x.shape[0] return x.view(batch_size, -1)
参考:https://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983
编辑:Flatten
现在是torch的一部分。参见https://pytorch.org/docs/stable/nn.html?highlight=flatten#torch.nn.Flatten