如何将Pytorch的Dataloader转换为numpy数组以使用matplotlib显示图像数据?

我是Pytorch的新手。我一直在尝试学习如何在开始训练我的CNN之前查看我的输入图像。我很难将图像转换为可以与matplotlib一起使用的形式。

到目前为止,我尝试了以下方法:

from multiprocessing import freeze_supportimport torchfrom torch import nnimport torchvisionfrom torch.autograd import Variablefrom torch.utils.data import DataLoader, Samplerfrom torchvision import datasetsfrom torchvision.transforms import transformsfrom torch.optim import Adamimport matplotlib.pyplot as pltimport numpy as npimport PILnum_classes = 5batch_size = 100num_of_workers = 5DATA_PATH_TRAIN = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\train'DATA_PATH_TEST = 'C:\\Users\Aeryes\PycharmProjects\simplecnn\images\\test'trans = transforms.Compose([    transforms.RandomHorizontalFlip(),    transforms.Resize(32),    transforms.CenterCrop(32),    transforms.ToPImage(),    transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))    ])train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)def imshow(img):    img = img / 2 + 0.5     # unnormalize    npimg = img.numpy()    print(npimg)    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))def main():    # get some random training images    dataiter = iter(train_loader)    images, labels = dataiter.next()    # show images    imshow(images)    # print labels    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))if __name__ == "__main__":    main()

然而,这会抛出错误:

  [[0.27058825 0.18431371 0.31764707 ... 0.18823528 0.3882353    0.27450982]   [0.23137254 0.11372548 0.24313724 ... 0.16862744 0.14117646    0.40784314]   [0.25490198 0.19607842 0.30588236 ... 0.27450982 0.25882354    0.34509805]   ...   [0.2784314  0.21960783 0.2352941  ... 0.5803922  0.46666667    0.25882354]   [0.26666668 0.16862744 0.23137254 ... 0.2901961  0.29803923    0.2509804 ]   [0.30980393 0.39607844 0.28627452 ... 0.1490196  0.10588235    0.19607842]]  ... [[[0.8980392  0.8784314  0.8509804  ... 0.627451   0.627451    0.627451  ]   [0.8509804  0.8235294  0.7921569  ... 0.54901963 0.5568628    0.56078434]   [0.7921569  0.7529412  0.7176471  ... 0.47058824 0.48235294    0.49411765]   ...   [0.3764706  0.38431373 0.3764706  ... 0.4509804  0.43137255    0.39607844]   [0.38431373 0.39607844 0.3882353  ... 0.4509804  0.43137255    0.39607844]   [0.3882353  0.4        0.39607844 ... 0.44313726 0.42352942    0.39215687]]  ... [[[0.06274509 0.09019607 0.11372548 ... 0.5803922  0.5176471    0.59607846]   [0.09411764 0.14509803 0.1372549  ... 0.5294118  0.49803922    0.5058824 ]   [0.04705882 0.09411764 0.10196078 ... 0.45882353 0.42352942    0.38431373]   ...   [0.15294117 0.12941176 0.1607843  ... 0.85882354 0.8509804    0.80784315]   [0.14509803 0.10588235 0.1607843  ... 0.8666667  0.85882354    0.8       ]   [0.1490196  0.10588235 0.16470587 ... 0.827451   0.8156863    0.7921569 ]]]Traceback (most recent call last):  File "image_loader.py", line 51, in <module>    main()  File "image_loader.py", line 46, in main    imshow(images)  File "image_loader.py", line 38, in imshow    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))  File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 598, in transpose    return _wrapfunc(a, 'transpose', axes)  File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 51, in _wrapfunc    return getattr(obj, method)(*args, **kwds)ValueError: repeated axis in transpose

我尝试打印出数组以获取维度,但我不知道该如何处理这些数据。这很 confusing。

我的直接问题是:如何在训练之前使用我的DataLoader对象中的张量查看输入图像?


回答:

首先,dataloader输出的是4维张量 – [batch, channel, height, width]。matplotlib和其他图像处理库通常需要[height, width, channel]。你关于使用转置的想法是对的,只是方式不对。

你的images中会有很多图像,所以首先你需要选择一个(或者编写一个for循环来保存所有图像)。这将简单地是images[i],通常我使用i=0

然后,你的转置应该将现在是[channel, height, width]的张量转换为[height, width, channel]的张量。为此,使用np.transpose(image.numpy(), (1, 2, 0)),与你的方法非常相似。

将它们组合在一起,你应该有

plt.imshow(np.transpose(images[0].numpy(), (1, 2, 0)))

有时候你需要调用.detach()(从计算图中分离这部分)和.cpu()(将数据从GPU传输到CPU),这取决于使用情况,将是

plt.imshow(np.transpose(images[0].cpu().detach().numpy(), (1, 2, 0)))

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注