从PyTorch数据集中提取子集

我有一个网络,我想在某个数据集上进行训练(例如,假设是CIFAR10)。我可以通过以下方式创建数据加载器对象:

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,                                        download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,                                          shuffle=True, num_workers=2)

我的问题如下:假设我想进行几次不同的训练迭代。比方说,我首先想用奇数位置的所有图像训练网络,然后用偶数位置的所有图像进行训练,依此类推。为了做到这一点,我需要能够访问这些图像。不幸的是,trainset似乎不允许这种访问。也就是说,尝试trainset[:1000]或者更一般的trainset[mask]会引发错误。

我可以这样做:

trainset.train_data=trainset.train_data[mask]trainset.train_labels=trainset.train_labels[mask]

然后

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,                                              shuffle=True, num_workers=2)

然而,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改了trainset.train_data,所以我需要重新定义trainset)。有什么方法可以避免这种情况吗?

理想情况下,我希望有类似于以下内容的东西:

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,                                              shuffle=True, num_workers=2)

回答:

您可以为数据集加载器定义一个自定义采样器,这样可以避免重新创建数据集(只需为每次不同的采样创建一个新的加载器)。

class YourSampler(Sampler):    def __init__(self, mask):        self.mask = mask    def __iter__(self):        return (self.indices[i] for i in torch.nonzero(self.mask))    def __len__(self):        return len(self.mask)trainset = torchvision.datasets.CIFAR10(root='./data', train=True,                                        download=True, transform=transform)sampler1 = YourSampler(your_mask)sampler2 = YourSampler(your_other_mask)trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,                                          sampler = sampler1, shuffle=False, num_workers=2)trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,                                          sampler = sampler2, shuffle=False, num_workers=2)

补充说明:您可以在这里找到更多信息:http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注