Reading the CIFAR-10 dataset in batches

I'm trying to read the CIFAR-10 dataset, which is distributed in batches from https://www.cs.toronto.edu/~kriz/cifar.html. I tried to use pickle to load it into a data frame and read its 'data' part, but I get this error:

KeyError                                  Traceback (most recent call last)
<ipython-input-24-8758b7a31925> in <module>()
----> 1 unpickle('datasets/cifar-10-batches-py/test_batch')

<ipython-input-23-04002b89d842> in unpickle(file)
      3     fo = open(file, 'rb')
      4     dict = pickle.load(fo, encoding ='bytes')
----> 5     X = dict['data']
      6     fo.close()
      7     return dict

KeyError: 'data'

I'm using IPython. Here is my code:

def unpickle(file):
    fo = open(file, 'rb')
    dict = pickle.load(fo, encoding ='bytes')
    X = dict['data']
    fo.close()
    return dict

unpickle('datasets/cifar-10-batches-py/test_batch')

Answer:

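The KeyError happens because pickle.load(fo, encoding='bytes') keeps the batch's Python 2 string keys as bytes, so the dictionary contains b'data' and b'labels' rather than 'data' and 'labels'. You can use the code below to read the CIFAR-10 dataset; it loads the batches with encoding='latin1', which decodes the keys to str. Just make sure you provide the correct directory where the batches are stored.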

import os
import platform

import numpy as np
from six.moves import cPickle as pickle

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)


def load_pickle(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
        return pickle.load(f)
    elif version[0] == '3':
        # latin1 decodes the Python 2 str keys to str, so datadict['data'] works
        return pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))


def load_CIFAR_batch(filename):
    """ load single batch of cifar """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3072)
        Y = np.array(Y)
        return X, Y


def load_CIFAR10(ROOT):
    """ load all of cifar """
    xs = []
    ys = []
    for b in range(1, 6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte


def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000,
                     cifar10_dir='../input/cifar-10-batches-py/'):
    # Load the raw CIFAR-10 data; cifar10_dir must contain data_batch_1..5 and test_batch
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Subsample the data
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]

    # Scale pixel values to [0, 1]; the validation split gets the same scaling
    x_train = X_train.astype('float32') / 255
    x_val = X_val.astype('float32') / 255
    x_test = X_test.astype('float32') / 255
    return x_train, y_train, x_val, y_val, x_test, y_test


# Invoke the above function to get our data.
x_train, y_train, x_val, y_val, x_test, y_test = get_CIFAR10_data()
print('Train data shape: ', x_train.shape)
print('Train labels shape: ', y_train.shape)
print('Validation data shape: ', x_val.shape)
print('Validation labels shape: ', y_val.shape)
print('Test data shape: ', x_test.shape)
print('Test labels shape: ', y_test.shape)
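If you would rather keep your original unpickle function than switch to the full loader above, the smallest change is to look the entry up by its bytes key. This is a sketch under one assumption: the batch was written by Python 2 (as the CIFAR-10 "python version" files are), so with encoding='bytes' every key comes back as bytes. The helper name unpickle_str_keys is hypothetical; it shows the alternative of decoding the keys with encoding='latin1'.

```python
import pickle


def unpickle(file):
    """Load one CIFAR-10 python batch and return its dictionary.

    With encoding='bytes', Python 2 str keys stay as bytes, so the
    keys are b'data', b'labels', ... and not 'data'.
    """
    with open(file, 'rb') as fo:  # closes the file even if loading fails
        batch = pickle.load(fo, encoding='bytes')
    return batch


def unpickle_str_keys(file):
    """Hypothetical variant: encoding='latin1' decodes the keys to str."""
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='latin1')
```

Usage: batch = unpickle('datasets/cifar-10-batches-py/test_batch'), then X = batch[b'data'] (note the b prefix). With unpickle_str_keys, the original lookup batch['data'] works unchanged.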
