import torchvision.datasets as dsetsimport torchvision.transforms as transformsimport torch.nn.initimport torch.nn.functional as Fdevice = "cuda" if torch.cuda.is_available() else "cpu"print(device)learning_rate = 0.001training_epochs = 15batch_size = 100mnist_train = dsets.MNIST(root='MNIST_data/', # 指定下载路径 train=True, # 指定为True以下载训练数据 transform=transforms.ToTensor(), # 转换为张量 download=True)mnist_test = dsets.MNIST(root='MNIST_data/', # 指定下载路径 train=False, # 指定为False以下载测试数据 transform=transforms.ToTensor(), # 转换为张量 download=True)
这是使用CNN加载MNIST数据分类代码的数据部分
我在参考的书中提到,只需提到那部分就可以查看训练集和测试集中有多少特定数值数据。
例如,你能告诉我训练集或测试集中有多少个’5’的数据吗?
我知道可以通过mnist_train.train_data或mnist_train.train_labels等访问数据张量,但我不知道如何知道特定数值数据的数量。请帮助我
回答:
你可以分别使用data
和targets
属性来访问数据集的任何分割的数据和标签。例如,这里你可以使用mnist_train.data
和mnist_train.labels
分别访问训练数据和标签。
由于这个数据集的targets
属性是一个torch.Tensor
,你可以通过使用torch.bincount
来统计每个目标的实例数量。由于总共有10个类别,输出将是一个长度为10的张量,其中第i个索引指定了类别i的数据点的数量。
示例:
>>> mnist_train = dsets.MNIST(root='MNIST_data/', train=True, transforms.ToTensor(), download=True)>>> mnist_train.targetstensor([5, 0, 4, ..., 5, 6, 8])>>> torch.bincount(mnist_train.targets, minlength=10)tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
你可以看到类别5在训练分割中拥有5,421个数据点。