如何使用PyTorch在自定义图像数据集中创建训练-验证集划分?

我想从我的原始训练集中创建一个训练+验证集。目录分为训练和测试。我加载了原始训练集,并希望将其划分为训练和验证集,以便在训练过程中使用train_loaderval_loader评估验证损失。

关于这个主题的文档不多,解释得也不清楚。


回答:

查看这里的回答。

我也在下面发布了这个回答。

======================================================

数据是使用ImageFolder读取的。任务是二元图像分类,数据集中有498张图像,这些图像在两个类别中平均分配(每个类别249张)。

img_dataset = ImageFolder(..., transforms=t)

1. SubsetRandomSampler

dataset_size = len(img_dataset)dataset_indices = list(range(dataset_size))np.random.shuffle(dataset_indices)val_split_index = int(np.floor(0.2 * dataset_size))train_idx, val_idx = dataset_indices[val_split_index:], dataset_indices[:val_split_index]train_sampler = SubsetRandomSampler(train_idx)val_sampler = SubsetRandomSampler(val_idx)train_loader = DataLoader(dataset=img_dataset, shuffle=False, batch_size=8, sampler=train_sampler)validation_loader = DataLoader(dataset=img_dataset, shuffle=False, batch_size=1, sampler=val_sampler)

2. random_split

在这里,总共498张图像中有400张被随机分配到训练集,其余98张分配到验证集。

dataset_train, dataset_valid = random_split(img_dataset, (400, 98))train_loader = DataLoader(dataset=dataset_train, shuffle=True, batch_size=8)val_loader = DataLoader(dataset=dataset_valid, shuffle=False, batch_size=1)

3. WeightedRandomSampler

如果有人在这里搜索WeightedRandomSampler,请查看@ptrblck的回答这里,以获得对下文所发生事情的良好解释。

那么,WeightedRandomSampler如何适合创建训练+验证集呢?因为与SubsetRandomSamplerrandom_split()不同,我们在这里不是为了训练和验证而进行划分。我们只是确保在训练期间每个批次获得相同数量的类别。

所以,我的猜测是我们需要在random_split()SubsetRandomSampler之后使用WeightedRandomSampler。但这并不能确保训练和验证集在类别之间具有相似的比例。

target_list = []for _, t in imgdataset:    target_list.append(t)target_list = torch.tensor(target_list)target_list = target_list[torch.randperm(len(target_list))]# get_class_distribution() 是一个接受数据集并# 返回一个包含类别计数的字典的函数。在这种情况下,# get_class_distribution(img_dataset) 返回以下内容 -# {'class_0': 249, 'class_0': 249}class_count = [i for i in get_class_distribution(img_dataset).values()]class_weights = 1./torch.tensor(class_count, dtype=torch.float) class_weights_all = class_weights[target_list]weighted_sampler = WeightedRandomSampler(    weights=class_weights_all,    num_samples=len(class_weights_all),    replacement=True)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注