我发现当我对图像使用’ToTensor’时,一张图像变成了9张显示。我查看了官方文档但找不到原因。那么为什么一张图片会变成9张图片呢?问题如图所示。
a = plt.imread('test.jpg')plt.imshow(a)plt.show()
transform = transforms.Compose([transforms.ToTensor()])b = transform(a)b = b.view(375,500,3)plt.imshow(b)
回答:
当你使用 transforms.ToTensor()
时,默认情况下它会将输入数组从 HWC 顺序更改为 CHW 顺序。对于绘图,你需要将通道维度重新推回到最后一个维度。
plt.imshow(b.permute(2, 0, 1))