Pytorch 错误:ValueError:图片应为二维或三维。得到四维

我在尝试按照这个教程这里进行操作。虽然当我选择内容图片和风格图片并尝试使用imshow()函数时,我遇到了这个错误:

ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

通过谷歌我并没有找到解决这个问题的有效方法。

这是我的代码:

import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom PIL import Imageimport matplotlib.pyplot as pltimport torchvision.transforms as transforms import torchvision.models as modelsimport copyimport numpy as np# 检测是否有可用的cuda用于GPU训练,否则将使用CPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)# 输出图像的期望大小imsize = 512 if torch.cuda.is_available() else 256print(imsize)# 辅助函数def image_loader(image_name, imsize):    # 调整导入图像的尺寸并将其转换为torch张量    loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])    image = Image.open(image_name)    # 需要假的批次维度以适应网络的输入维度    image = loader(image).unsqueeze(0)    return image.to(device, torch.float)# 辅助函数以将张量显示为PIL图像def imshow(tensor, title=None):    unloader = transforms.ToPILImage()    image = tensor.cpu().clone()    image = unloader(image)    plt.imshow(image)    if title is not None:        plt.title(title)    plt.pause(0.001) # 暂停以便更新图表# 加载图像image_directory = './images/'style_img = image_loader(image_directory + "pb.jpg", imsize)content_img = image_loader(image_directory + "content.jpg", imsize)assert style_img.size() == content_img.size(), "我们需要导入相同大小的风格和内容图像"plt.figure()imshow(style_img, title='风格图像')

任何建议都将非常有帮助。

这里是供参考的风格和内容图像:

输入图像描述

输入图像描述


回答:

matplotlib.pyplotimshow函数中期望的是二维(灰度,dimensions=(W,H))或三维(彩色,dimensions = (W,H,color channel))。

您可能在张量中仍然保留了批次大小作为第一个维度,因为在您的代码中,您执行了以下操作:

# 需要假的批次维度以适应网络的输入维度image = loader(image).unsqueeze(0)

这增加了第一个维度。如果是这样,请尝试使用以下方法之一:

plt.imshow(np.squeeze(image))

plt.imshow(image[0])

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注