我想从我的原始训练集中创建一个训练+验证集。目录分为训练和测试。我加载了原始训练集,并希望将其划分为训练和验证集,以便在训练过程中使用train_loader和val_loader评估验证损失。
关于这个主题的文档不多,解释得也不清楚。
回答:
查看这里的回答。
我也在下面发布了这个回答。
======================================================
数据是使用ImageFolder
读取的。任务是二元图像分类,数据集中有498张图像,这些图像在两个类别中平均分配(每个类别249张)。
img_dataset = ImageFolder(..., transforms=t)
1. SubsetRandomSampler
dataset_size = len(img_dataset)dataset_indices = list(range(dataset_size))np.random.shuffle(dataset_indices)val_split_index = int(np.floor(0.2 * dataset_size))train_idx, val_idx = dataset_indices[val_split_index:], dataset_indices[:val_split_index]train_sampler = SubsetRandomSampler(train_idx)val_sampler = SubsetRandomSampler(val_idx)train_loader = DataLoader(dataset=img_dataset, shuffle=False, batch_size=8, sampler=train_sampler)validation_loader = DataLoader(dataset=img_dataset, shuffle=False, batch_size=1, sampler=val_sampler)
2. random_split
在这里,总共498张图像中有400张被随机分配到训练集,其余98张分配到验证集。
dataset_train, dataset_valid = random_split(img_dataset, (400, 98))train_loader = DataLoader(dataset=dataset_train, shuffle=True, batch_size=8)val_loader = DataLoader(dataset=dataset_valid, shuffle=False, batch_size=1)
3. WeightedRandomSampler
如果有人在这里搜索
WeightedRandomSampler
,请查看@ptrblck的回答这里,以获得对下文所发生事情的良好解释。
那么,WeightedRandomSampler
如何适合创建训练+验证集呢?因为与SubsetRandomSampler
或random_split()
不同,我们在这里不是为了训练和验证而进行划分。我们只是确保在训练期间每个批次获得相同数量的类别。
所以,我的猜测是我们需要在random_split()
或SubsetRandomSampler
之后使用WeightedRandomSampler
。但这并不能确保训练和验证集在类别之间具有相似的比例。
target_list = []for _, t in imgdataset: target_list.append(t)target_list = torch.tensor(target_list)target_list = target_list[torch.randperm(len(target_list))]# get_class_distribution() 是一个接受数据集并# 返回一个包含类别计数的字典的函数。在这种情况下,# get_class_distribution(img_dataset) 返回以下内容 -# {'class_0': 249, 'class_0': 249}class_count = [i for i in get_class_distribution(img_dataset).values()]class_weights = 1./torch.tensor(class_count, dtype=torch.float) class_weights_all = class_weights[target_list]weighted_sampler = WeightedRandomSampler( weights=class_weights_all, num_samples=len(class_weights_all), replacement=True)