我在尝试按照这个教程这里进行操作。虽然当我选择内容图片和风格图片并尝试使用imshow()
函数时,我遇到了这个错误:
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.
通过谷歌我并没有找到解决这个问题的有效方法。
这是我的代码:
import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom PIL import Imageimport matplotlib.pyplot as pltimport torchvision.transforms as transforms import torchvision.models as modelsimport copyimport numpy as np# 检测是否有可用的cuda用于GPU训练,否则将使用CPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)# 输出图像的期望大小imsize = 512 if torch.cuda.is_available() else 256print(imsize)# 辅助函数def image_loader(image_name, imsize): # 调整导入图像的尺寸并将其转换为torch张量 loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()]) image = Image.open(image_name) # 需要假的批次维度以适应网络的输入维度 image = loader(image).unsqueeze(0) return image.to(device, torch.float)# 辅助函数以将张量显示为PIL图像def imshow(tensor, title=None): unloader = transforms.ToPILImage() image = tensor.cpu().clone() image = unloader(image) plt.imshow(image) if title is not None: plt.title(title) plt.pause(0.001) # 暂停以便更新图表# 加载图像image_directory = './images/'style_img = image_loader(image_directory + "pb.jpg", imsize)content_img = image_loader(image_directory + "content.jpg", imsize)assert style_img.size() == content_img.size(), "我们需要导入相同大小的风格和内容图像"plt.figure()imshow(style_img, title='风格图像')
任何建议都将非常有帮助。
这里是供参考的风格和内容图像:
回答:
matplotlib.pyplot
在imshow
函数中期望的是二维(灰度,dimensions=(W,H))或三维(彩色,dimensions = (W,H,color channel))。
您可能在张量中仍然保留了批次大小作为第一个维度,因为在您的代码中,您执行了以下操作:
# 需要假的批次维度以适应网络的输入维度image = loader(image).unsqueeze(0)
这增加了第一个维度。如果是这样,请尝试使用以下方法之一:
plt.imshow(np.squeeze(image))
或
plt.imshow(image[0])