我在使用八度卷积,并设置了一个BatchNorm2d的适配,但不知为何出现了以下错误:
RuntimeError: running_mean should contain 64 elements not 0
我设置了一些调试打印来检查我的张量维度是否有问题,但未能找到原因。以下是我的类:
class _BatchNorm2d(nn.Module): def __init__(self, num_features, alpha_in=0, alpha_out=0, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(_BatchNorm2d, self).__init__() hf_ch = int(num_features * (1 - alpha_out)) lf_ch = num_features - hf_ch self.bnh = nn.BatchNorm2d(hf_ch) self.bnl = nn.BatchNorm2d(lf_ch) def forward(self, x): if isinstance(x, tuple): hf, lf = x print("IN ON BN: ",lf.shape if lf is not None else None) #DEBUGGING PRINT print(self.bnl) #DEBUGGING PRINT hf = self.bnh(hf) if type(hf) == torch.Tensor else hf lf = self.bnh(lf) if type(lf) == torch.Tensor else lf #THIS IS THE LINE ACCUSING THE ERROR print("ENDED BN") return hf, lf else: return self.bnh(x)
以下是错误的打印信息:
IN ON BN: torch.Size([32, 64, 3, 3])BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
对我来说,函数应该能正常工作,因为x有64个通道,而bn期望64个通道。
编辑:可能还需要提到的是,错误只在alpha值为1时发生。然而,我不理解其中的原因,因为体积仍然相同。
回答:
已解决。问题出在低频BN的调用上有一个打字错误。
hf = self.bnh(hf) if type(hf) == torch.Tensor else hf lf = self.bnh(lf) if type(lf) == torch.Tensor else lf
应该改为
hf = self.bnh(hf) if type(hf) == torch.Tensor else hf lf = self.bnl(lf) if type(lf) == torch.Tensor else lf