使用PyTorch绘制变换后的(增强后的)图像

我想使用一种图像增强技术(例如旋转或水平翻转),将其应用于CIFAR-10数据集中的一些图像,并使用PyTorch绘制这些图像。

我知道我们可以使用以下代码来增强图像:

from torchvision import models, datasets, transformsfrom torchvision.datasets import CIFAR10data_transforms = transforms.Compose([        # 添加增强        transforms.RandomHorizontalFlip(p=0.5),        # torchvision数据集的输出是范围在[0, 1]的PILImage图像。        # 我们将它们转换为范围在[-1, 1]的归一化张量        transforms.ToTensor(),        transforms.Normalize(mean, std)    ])

然后我在加载Cifar10数据集时使用了上述变换:

train_set = CIFAR10(    root='./data/',    train=True,    download=True,    transform=data_transforms['train'])

据我所知,当使用此代码时,所有的CIFAR10数据集都被变换了。

问题

我的问题是,如何对数据集中的某些图像使用数据变换或增强技术并绘制它们?例如,10张图像及其增强后的图像。


回答:

当使用此代码时,所有的CIFAR10数据集都被变换了

实际上,变换流程只有在用户通过__getitem__函数或通过数据加载器获取数据集中的图像时才会被调用。因此,此时train_set并不包含增强后的图像,它们是在使用时动态变换的。


你需要构建另一个没有增强的数据库。

>>> non_augmented = CIFAR10(...     root='./data/',...     train=True,...     download=True)>>> train_set = CIFAR10(...     root='./data/',...     train=True,...     download=True,...     transform=data_transforms)

将一些图像堆叠在一起:

>>> imgs = torch.stack((*[non_augmented[i][0] for i in range(10)],                        *[train_set[i][0] for i in range(10)]))>>> imgs.shapetorch.Size([20, 3, 32, 32])

然后torchvision.utils.make_grid可以用来创建所需的布局:

>>> grid = torchvision.utils.make_grid(imgs, nrow=10)

就这样!

>>> transforms.ToPILImage()(grid)

enter image description here

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注