在PyTorch的ImageFolder中创建新观察

我是PyTorch的新手,我想做的事情可能很简单,但我没有在网上找到任何关于在不向图像(在我这里是)文件夹中添加新观察的情况下增加观察数量的信息。我不想向文件夹中添加图像,因为我想尝试不同的变换,看看哪个效果最好,而不必每次都删除图像。所以我做的就是:

trf = transforms.Compose([    transforms.ToTensor(),    transforms.RandomRotation(degrees=45),    transforms.Grayscale(num_output_channels=1),    transforms.Normalize(0, 1),    transforms.functional.invert])train_data = torchvision.datasets.ImageFolder(root='./splitted_data/train', transform= trf)print(len(train_data))train = DataLoader(train_data, batch_size= batch_size, shuffle= True,  num_workers= os.cpu_count())

这里的输出将与所有文件夹中图像的数量相同,这意味着变换已经被应用到现有的观察上,但这并不是我想要达到的。我希望每个变换都是一个独立的副本。我该怎么做呢?


回答:

你可以实现一个变换包装器,它将顺序应用变换并输出每一种变换组合。Torchvision的随机变换的问题在于,当变换被调用时,参数会被抽样。这使得重现相同的变换变得困难。一个替代方案是堆叠或拼接所有图像,并在该堆叠上应用一次变换。

我将变换管道分为三个部分:预处理和后处理变换(后者不应该是随机的,因为它是单独应用的)。至于主要的变换,它们是你想要创建组合的变换列表,这里是RandomRotationGrayscale

请注意,当处理影响通道数量的变换(如Grayscale)时,此解决方案有一定的限制。一般来说,你希望保持相同的张量维度,否则你的拼接和/或堆叠将会失败。

这是一个可能的解决方案:

class Combination(nn.Module):    def __init__(self, transforms, pre, post):        super().__init__()        self.transforms = transforms        self.pre = T.Compose(pre)        self.post = T.Compose(post)    def stacked_t(self, t, x):        lengths = [len(o) for o in x]        return t(torch.cat(x)).split(lengths)    def forward(self, x):        out = [self.pre(x)[None]]        for t in transforms:            out += self.stacked_t(t, out) # <- for every transform `t` we double                                          #    the number of instances in` out`        out = [self.post(o)[0] for o in out]        return out

这里是一个带有输入图像的使用示例:

>>> img

enter image description here

初始化变换组合:

>>> t = Combination(pre=[T.ToTensor()],...                 post=[T.Normalize(0, 1),...                       T.functional.invert],...                 transforms=[T.RandomRotation(degrees=45),...                             T.Grayscale(num_output_channels=1)])

这是不同变换组合的预览:

>>> img_ = t(img)
img_[0] img_[1] img_[2] img_[3]

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注