我目前正在训练一个用于二元分类的3D CNN,其标签相对稀疏(标签数据中大约1%的体素对应于目标类)。
为了在训练过程中进行基本的健全性检查(例如,网络是否在学习?),能够向网络展示一小部分手动挑选的训练样本,这些样本中目标类标签的比例高于平均水平,将会非常有用。
根据Pytorch文档的建议,我实现了自己的dataset
类(继承自torch.utils.data.Dataset
),通过它的__get_item__
方法向torch.utils.data.DataLoader
提供训练样本。
在Pytorch教程中,我发现DataLoader
被用作迭代器来生成训练循环,如下所示:
for i, data in enumerate(self.dataloader): # 获取训练数据 inputs, labels = data # 训练网络 # [...]
现在我想知道是否有简单的方法可以加载一个或几个特定的训练样本(使用Dataset
的__get_item__
方法理解的线性索引)。然而,DataLoader
没有__get_item__
方法,反复调用__next__
直到达到所需的索引似乎不够优雅。
显然,解决这个问题的一种可能方法是定义一个自定义的sampler
或batch_sampler
,继承自抽象的torch.utils.data.Sampler
。但这似乎是为了检索几个特定样本而过于复杂了。
我猜想我在这里忽略了一些非常简单和明显的东西。欢迎任何建议!
回答:
以防将来有类似问题的人看到这个:
我最终使用的快速且不完美的解决方法是,在训练循环中绕过dataloader
,直接访问其关联的dataset
属性。假设我们想通过反复呈现一个手动挑选的训练样本(其线性索引为sample_idx
,由数据集类定义)来快速检查我们的网络是否在学习。
那么可以这样做:
for i, _ in enumerate(self.dataloader): # 获取训练数据 # inputs, labels = data inputs, labels = self.dataloader.dataset[sample_idx] inputs = inputs.unsqueeze(0) labels = labels.unsqueeze(0) # 训练网络 # [...]
编辑:
简短说明一下,由于一些人发现这个解决方法有帮助:使用这个技巧时,我发现以num_workers = 0
实例化DataLoader
是至关重要的。否则,可能会发生内存分段错误,在这种情况下,你可能会得到非常奇怪的训练数据。