num_samples = 10def predict(x): sampled_models = [guide(None, None) for _ in range(num_samples)] yhats = [model(x).data for model in sampled_models] mean = torch.mean(torch.stack(yhats), 0) return np.argmax(mean.numpy(), axis=1)print('当网络被强制预测时的预测结果')correct = 0total = 0for j, data in enumerate(test_loader): images, labels = data predicted = predict(images.view(-1,28*28)) total += labels.size(0) correct += (predicted == labels).sum().item()print("准确率: %d %%" % (100 * correct / total))
错误:
correct += (predicted == labels).sum().item() TypeError: eq() 接收到无效的参数组合 - 接收到 (numpy.ndarray),但期望的是以下之一: * (Tensor other) 不匹配,因为某些参数类型无效:(!numpy.ndarray!)* (Number other) 不匹配,因为某些参数类型无效:(!numpy.ndarray!)
*
回答:
您试图比较 predicted
和 labels
。然而,您的 predicted
是一个 np.array
,而 labels
是一个 torch.tensor
,因此 eq()
(即 ==
操作符)无法在它们之间进行比较。
将 np.argmax
替换为 torch.argmax
:
return torch.argmax(mean, dim=1)
这样应该就可以了。