我在尝试加载MNIST数据集时遇到了以下错误:
TypeError: 元组索引必须是整数或切片,而不能是字符串
这是我的代码:
import numpy as npimport tensorflow as tfimport tensorflow_datasets as tfdsmnist_dataset = tfds.load(name='mnist', with_info=True, as_supervised=True)mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
这一行代码引发了错误:
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
回答:
如果你使用了with_info=True
参数,你需要相应地进行解包:
mnist_dataset, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
你之前的方式下,mnist_dataset
是一个包含两个项目的字典和一个tfds.core.DatasetInfo
对象的元组:
( {'test': <PrefetchDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>,'train': <PrefetchDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)> }, tfds.core.DatasetInfo(name='mnist', etc))