我正在使用31个类别(Office数据集)开发一个图像分类器。每个类别都有一个文件夹。我编写了一个使用PyTorch的Python脚本,该脚本使用datasets.ImageFolder
加载数据集,并为每张图像分配一个标签,然后进行训练。以下是我的加载数据的代码片段:
from torchvision import datasets, transformsimport torchdef load_training(root_path, dir, batch_size, kwargs): transform = transforms.Compose( [transforms.Resize([256, 256]), transforms.RandomCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor()]) data = datasets.ImageFolder(root=root_path + dir, transform=transform) train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs) return train_loader
代码会为每个文件夹内的所有图像分配相同的标签。有没有办法找到哪个标签被分配给了哪个图像/图像文件夹?
回答:
ImageFolder类有一个属性class_to_idx
,它是一个字典,将类别的名称映射到索引(标签)。因此,您可以使用data.classes
访问类别,并通过data.class_to_idx
获取每个类别的标签。
参考资料:https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py