当我试图了解torchvision.datasets.cifar.CIFAR10内部是什么时,我编写了一些简单的代码
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)print(trainset[1])print(trainset[:10])print(type(trainset))
然而,当我尝试时,我得到了一些错误
print(trainset[:10])
错误信息是
TypeError: Cannot handle this data type
我想知道为什么我可以使用trainset[1]
,但不能使用trainset[:10]
?
回答:
CIFAR10不支持切片操作,这就是你得到那个错误的原因。如果你想要前10个,你需要这样做:
print([trainset[i] for i in range(10)])
更多信息
你可以对CIFAR10类的实例进行索引操作的主要原因是该类实现了__getitem__()
函数。
所以,当你调用trainset[i]
时,实际上是在调用trainset.__getitem__(i)
现在,在Python3中,切片表达式也是通过__getitem__()
处理的,其中切片表达式作为一个slice对象传递给__getitem__()
。
因此,trainset[2:10]
相当于trainset.__getitem__(slice(2, 10))
由于传递给__getitem__
的两种不同类型的对象预期执行完全不同的事情,你必须明确处理它们。
不幸的是,CIFAR10类中的__getitem__
方法实现并没有这样做,如下所示:
def __getitem__(self, index): if self.train: img, target = self.train_data[index], self.train_labels[index] else: img, target = self.test_data[index], self.test_labels[index] # 这样做是为了与所有其他数据集保持一致 # 返回一个PIL图像 img = Image.fromarray(img) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target