如何随机打乱Pytorch数据集的标签?

我是Pytorch的新手,遇到了一些技术问题。我已经下载了MNIST数据集,使用了以下命令:

train_dataset = dsets.MNIST(root='./data',                             train=True,                             transform=transforms.ToTensor(),                            download=True)

现在我需要在这个数据集上进行一些实验,但使用随机标签。我该如何随机打乱/重新分配这些标签呢?

我尝试手动操作,但系统提示“’tuple’对象不支持项目分配”。那我该怎么做呢?

第二个问题:如何从数据集中删除一个训练点?当我尝试这样做时,得到的是相同的错误。

谢谢!!


回答:

如果你只想打乱目标,可以使用target_transform参数。例如:

train_dataset = dsets.MNIST(root='./data',                             train=True,                             transform=transforms.ToTensor(),                            target_transform=lambda y: torch.randint(0, 10, (1,)).item(),                            download=True)

如果你想对数据集进行更复杂的调整,可以完全包装mnist

class MyTwistedMNIST(torch.utils.data.Dataset):  def __init__(self, my_args):    super(MyTwistedMNIST, self).__init__()    self.orig_mnist = dset.MNIST(...)    def __getitem__(self, index):    x, y = self.orig_mnist[index]  # 获取原始项目    my_x = # 更改输入数字图像 x ?    my_y = # 更改原始标签 y ?    return my_x, my_y  def __len__(self):    return self.orig_mnist.__len__()

如果你想完全丢弃原始mnist中的一些元素,通过包装原始mnist,你的MyTwistedMNIST类可以返回一个比self.orig_mnist.__len__()小的len,反映你实际想要处理的mnist示例的数量。此外,你需要将新示例的index映射到原始mnist的索引上。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注