我试图读取 CIFAR-10 数据集,该数据集以批次(batch)文件的形式提供,来自 https://www.cs.toronto.edu/~kriz/cifar.html 。我想用 pickle 将其加载进来、放入数据框(DataFrame),并读取其中的 'data' 部分,但是遇到了下面这个错误。
KeyError                                  Traceback (most recent call last)
<ipython-input-24-8758b7a31925> in <module>()
----> 1 unpickle('datasets/cifar-10-batches-py/test_batch')

<ipython-input-23-04002b89d842> in unpickle(file)
      3     fo = open(file, 'rb')
      4     dict = pickle.load(fo, encoding ='bytes')
----> 5     X = dict['data']
      6     fo.close()
      7     return dict

KeyError: 'data'
我使用的是ipython,以下是我的代码:
def unpickle(file):
    """Load one pickled CIFAR-10 batch file and return its dictionary.

    The file is unpickled with ``encoding='bytes'``, so every key of the
    returned dictionary is a ``bytes`` object (``b'data'``, not ``'data'``).
    Indexing with the ``str`` key ``'data'`` is what raised the reported
    ``KeyError: 'data'``.

    Args:
        file: Path to a batch file, e.g. ``.../cifar-10-batches-py/test_batch``.

    Returns:
        The unpickled dictionary, with ``bytes`` keys.

    Raises:
        KeyError: If the dictionary has no ``b'data'`` entry.

    Note:
        ``pickle`` can execute arbitrary code while loading; only use this on
        batch files obtained from a trusted source.
    """
    # ``with`` closes the file even if unpickling fails.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    # Fail early, with the correct key type, if this is not a data batch.
    if b'data' not in batch:
        raise KeyError(b'data')
    return batch


if __name__ == '__main__':
    unpickle('datasets/cifar-10-batches-py/test_batch')
回答:
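出现 KeyError 的原因是:使用 encoding='bytes' 反序列化时,字典的键是 bytes 类型,因此应写成 dict[b'data'];或者改用 encoding='latin1' 加载,这样键就是普通字符串,可以直接用 'data' 访问。您也可以使用下面的代码来读取 CIFAR-10 数据集,只需确保您提供了存放批次文件的正确目录即可: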
import tensorflow as tfimport pandas as pdimport numpy as npimport mathimport timeitimport matplotlib.pyplot as pltfrom six.moves import cPickle as pickleimport osimport platformfrom subprocess import check_outputclasses = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')%matplotlib inlineimg_rows, img_cols = 32, 32input_shape = (img_rows, img_cols, 3)def load_pickle(f): version = platform.python_version_tuple() if version[0] == '2': return pickle.load(f) elif version[0] == '3': return pickle.load(f, encoding='latin1') raise ValueError("invalid python version: {}".format(version))def load_CIFAR_batch(filename): """ load single batch of cifar """ with open(filename, 'rb') as f: datadict = load_pickle(f) X = datadict['data'] Y = datadict['labels'] X = X.reshape(10000,3072) Y = np.array(Y) return X, Ydef load_CIFAR10(ROOT): """ load all of cifar """ xs = [] ys = [] for b in range(1,6): f = os.path.join(ROOT, 'data_batch_%d' % (b, )) X, Y = load_CIFAR_batch(f) xs.append(X) ys.append(Y) Xtr = np.concatenate(xs) Ytr = np.concatenate(ys) del X, Y Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch')) return Xtr, Ytr, Xte, Ytedef get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000): # Load the raw CIFAR-10 data cifar10_dir = '../input/cifar-10-batches-py/' X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir) # Subsample the data mask = range(num_training, num_training + num_validation) X_val = X_train[mask] y_val = y_train[mask] mask = range(num_training) X_train = X_train[mask] y_train = y_train[mask] mask = range(num_test) X_test = X_test[mask] y_test = y_test[mask] x_train = X_train.astype('float32') x_test = X_test.astype('float32') x_train /= 255 x_test /= 255 return x_train, y_train, X_val, y_val, x_test, y_test# Invoke the above function to get our data.x_train, y_train, x_val, y_val, x_test, y_test = get_CIFAR10_data()print('Train data shape: ', x_train.shape)print('Train labels shape: ', 
y_train.shape)print('Validation data shape: ', x_val.shape)print('Validation labels shape: ', y_val.shape)print('Test data shape: ', x_test.shape)print('Test labels shape: ', y_test.shape)