我在尝试使用TPSSpatialTransformerNetwork训练IAM数据集时,最终得到了一个错误:形状 ‘[-1, 2, 4, 28]’ 对于大小为 768 的输入是无效的
数据集中每张图像的大小为 (32,128)。我无法弄清楚错误步骤中得到的形状。以下是代码:
class TPS_SpatialTransformerNetwork(nn.Module): def __init__(self): super(TPS_SpatialTransformerNetwork, self).__init__() self.conv1 = nn.Conv2d(1, 79, kernel_size=5) self.conv2 = nn.Conv2d(79, 256, kernel_size=5) self.conv2_drop = nn.Dropout2d() self.fc1 = nn.Linear(256, 512) self.fc2 = nn.Linear(512, 79) # Spatial transformer localization-network self.localization = nn.Sequential( nn.Conv2d(1, 8, kernel_size=7), nn.MaxPool2d(2, stride=2), nn.ReLU(True), nn.Conv2d(8, 79, kernel_size=5), nn.MaxPool2d(2, stride=2), nn.ReLU(True) ) # Regressor for the 3 * 2 affine matrix self.fc_loc = nn.Sequential( nn.Linear(79 * 4 * 28, 32), nn.ReLU(True), nn.Linear(32, 3 * 2) ) # Initialize the weights/bias with identity transformation self.fc_loc[2].weight.data.zero_() self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) # Spatial transformer network forward function def stn(self, x): xs = self.localization(x) xs = xs.view(-1, 79 * 4 * 28) theta = self.fc_loc(xs) theta = theta.view(-1, 2, 4,28) grid = F.affine_grid(theta, x.size()) x = F.grid_sample(x, grid) return x def forward(self, x): # transform the input x = self.stn(x) # Perform the usual forward pass x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) x = x.view(-1, 320) x = F.relu(self.fc1(x)) x = F.dropout(x, training=self.training) x = self.fc2(x) return F.log_softmax(x, dim=1)4 frames/content/drive/My Drive/OCR/transformation.py in stn(self, x) 41 xs = xs.view(-1, 79 * 4 * 28) 42 theta = self.fc_loc(xs)---> 43 theta = theta.view(-1, 2, 4,28) 44 45 RuntimeError: shape '[-1, 2, 4, 28]' is invalid for input of size 768
回答:
我在评论后进行了快速搜索,并找到了这个链接,它详细介绍了如何在pytorch中进行STN(在官方pytorch网站上)。我不知道你是如何得到你的调整大小命令的,但我最初评论中提出的建议似乎是正确的,你试图将6个特征调整到大小为[2,4,28]的矩阵中,这永远不会成功。如下所示,pytorch网站上的做法是:
def stn(self, x): xs = self.localization(x) xs = xs.view(-1, 10 * 3 * 3) theta = self.fc_loc(xs) theta = theta.view(-1, 2, 3) #<-----------KEY LINE HERE grid = F.affine_grid(theta, x.size()) x = F.grid_sample(x, grid) return x
其中theta张量是使用对应于深度为6的维度进行重塑的。
theta张量的大小为6的原因是它由self.fc_loc方法提供,该方法本身如下所示:
self.fc_loc = nn.Sequential( nn.Linear(79 * 4 * 28, 32), nn.ReLU(True), nn.Linear(32, 3 * 2) )
如果我们看最后一行,可以看到这个顺序块的输出(其中每一行都是按照图中的顺序构建的,即顺序!)是一个线性块,具有32个输入和6个输出(3*2)。因此你的theta将是形状[-1, 6],其中-1是这段代码中批次大小的占位符。