我有一个自定义的图像数据集及其目标。我已经在PyTorch中创建了一个训练数据集。我想将其拆分为三部分:训练、验证和测试。如何操作呢?
回答:
一旦你有了“主”数据集,你可以使用 data.Subset
来进行拆分。
这是一个随机拆分的示例:
import torchfrom torch.utils import dataimport randommaster = data.Dataset( ... ) # 你的“主”数据集n = len(master) # 总元素数量n_test = int( n * .05 ) # 测试/验证元素的数量n_train = n - 2 * n_testidx = list(range(n)) # 所有元素的索引random.shuffle(idx) # 就地打乱索引以便随机拆分train_idx = idx[:n_train]val_idx = idx[n_train:(n_train + n_test)]test_idx = idx[(n_train + n_test):]train_set = data.Subset(master, train_idx)val_set = data.Subset(master, val_idx)test_set = data.Subset(master, test_idx)
这也可以通过 data.random_split
来实现:
train_set, val_set, test_set = data.random_split(master, (n_train, n_val, n_test))