我想使用一种图像增强技术(例如旋转或水平翻转),将其应用于CIFAR-10数据集中的一些图像,并使用PyTorch绘制这些图像。
我知道我们可以使用以下代码来增强图像:
from torchvision import models, datasets, transformsfrom torchvision.datasets import CIFAR10data_transforms = transforms.Compose([ # 添加增强 transforms.RandomHorizontalFlip(p=0.5), # torchvision数据集的输出是范围在[0, 1]的PILImage图像。 # 我们将它们转换为范围在[-1, 1]的归一化张量 transforms.ToTensor(), transforms.Normalize(mean, std) ])
然后我在加载Cifar10数据集时使用了上述变换:
train_set = CIFAR10( root='./data/', train=True, download=True, transform=data_transforms['train'])
据我所知,当使用此代码时,所有的CIFAR10数据集都被变换了。
问题
我的问题是,如何对数据集中的某些图像使用数据变换或增强技术并绘制它们?例如,10张图像及其增强后的图像。
回答:
当使用此代码时,所有的CIFAR10数据集都被变换了
实际上,变换流程只有在用户通过__getitem__
函数或通过数据加载器获取数据集中的图像时才会被调用。因此,此时train_set
并不包含增强后的图像,它们是在使用时动态变换的。
你需要构建另一个没有增强的数据库。
>>> non_augmented = CIFAR10(... root='./data/',... train=True,... download=True)>>> train_set = CIFAR10(... root='./data/',... train=True,... download=True,... transform=data_transforms)
将一些图像堆叠在一起:
>>> imgs = torch.stack((*[non_augmented[i][0] for i in range(10)], *[train_set[i][0] for i in range(10)]))>>> imgs.shapetorch.Size([20, 3, 32, 32])
然后torchvision.utils.make_grid
可以用来创建所需的布局:
>>> grid = torchvision.utils.make_grid(imgs, nrow=10)
就这样!
>>> transforms.ToPILImage()(grid)