我创建了一个自定义的Dataset类,它继承自PyTorch的Dataset类,以便处理我已经预处理的自定义数据集。
当我尝试创建DataLoader对象时,遇到了以下错误:
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __init__(self, dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last, timeout, worker_init_fn) 174 if sampler is None: 175 if shuffle:--> 176 sampler = RandomSampler(dataset) 177 else: 178 sampler = SequentialSampler(dataset)/usr/local/lib/python3.6/dist-packages/torch/utils/data/sampler.py in __init__(self, data_source, replacement, num_samples) 62 "since a random permute will be performed.") 63 ---> 64 if not isinstance(self.num_samples, int) or self.num_samples <= 0: 65 raise ValueError("num_samples should be a positive integer " 66 "value, but got num_samples={}".format(self.num_samples))/usr/local/lib/python3.6/dist-packages/torch/utils/data/sampler.py in num_samples(self) 70 # dataset size might change at runtime 71 if self._num_samples is None:---> 72 return len(self.data_source) 73 return self._num_samples 74 /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataset.py in __len__(self) 18 19 def __len__(self):---> 20 raise NotImplementedError 21 22 def __add__(self, other):NotImplementedError:
所以,错误信息是关于dataset.py中len()函数未实现,对吗?但我确实实现了它,还有getitem()和init()函数。
我该如何解决这个问题?谢谢
回答:
请确保代码中的名称正确。它应该是__len__
。