我是Pytorch的新手,遇到了一些技术问题。我已经下载了MNIST数据集,使用了以下命令:
train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
现在我需要在这个数据集上进行一些实验,但使用随机标签。我该如何随机打乱/重新分配这些标签呢?
我尝试手动操作,但系统提示“’tuple’对象不支持项目分配”。那我该怎么做呢?
第二个问题:如何从数据集中删除一个训练点?当我尝试这样做时,得到的是相同的错误。
谢谢!!
回答:
如果你只想打乱目标,可以使用target_transform
参数。例如:
train_dataset = dsets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), target_transform=lambda y: torch.randint(0, 10, (1,)).item(), download=True)
如果你想对数据集进行更复杂的调整,可以完全包装mnist
class MyTwistedMNIST(torch.utils.data.Dataset): def __init__(self, my_args): super(MyTwistedMNIST, self).__init__() self.orig_mnist = dset.MNIST(...) def __getitem__(self, index): x, y = self.orig_mnist[index] # 获取原始项目 my_x = # 更改输入数字图像 x ? my_y = # 更改原始标签 y ? return my_x, my_y def __len__(self): return self.orig_mnist.__len__()
如果你想完全丢弃原始mnist中的一些元素,通过包装原始mnist,你的MyTwistedMNIST
类可以返回一个比self.orig_mnist.__len__()
小的len
,反映你实际想要处理的mnist示例的数量。此外,你需要将新示例的index
映射到原始mnist的索引上。