我有一个像这样的数据集类:
class LoadDataset(Dataset): def __init__(self, data, label): self.data = data self.label = label def __len__(self): dlen = len(self.data) return dlen def __getitem__(self, index): return self.data, self.label
然后我加载了形状为[485, 1, 32, 32]的图像数据集
train_dataset = LoadDataset(xtrain, ytrain)print(len(train_dataset))# 输出 485
然后我使用DataLoader
加载数据
train_loader = DataLoader(train_dataset, batch_size=32)
接着我遍历数据:
for epoch in range(num_epoch): for inputs, labels in train_loader: print(inputs.shape)
输出打印的是torch.Size([32, 485, 1, 32, 32])
,它应该是torch.Size([32, 1, 32, 32])
,
谁能帮帮我?
回答:
__getitem__
方法应该返回一个数据片段,你返回了所有数据。
试试这个:
class LoadDataset(Dataset): def __init__(self, data, label): self.data = data self.label = label def __len__(self): dlen = len(self.data) llen = len(self.label) # 这里不同 return min(dlen, llen) # 这里不同 def __getitem__(self, index): return self.data[index], self.label[index] # 这里不同