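I am trying to read the CIFAR-10 dataset, which is distributed in batches at https://www.cs.toronto.edu/~kriz/cifar.html. I tried to load a batch with pickle and read its 'data' part so I could put it in a DataFrame, but I get this error: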
KeyError                                  Traceback (most recent call last)
<ipython-input-24-8758b7a31925> in <module>()
----> 1 unpickle('datasets/cifar-10-batches-py/test_batch')

<ipython-input-23-04002b89d842> in unpickle(file)
      3     fo = open(file, 'rb')
      4     dict = pickle.load(fo, encoding ='bytes')
----> 5     X = dict['data']
      6     fo.close()
      7     return dict

KeyError: 'data'
I am using IPython, and here is my code:
def unpickle(file):
    fo = open(file, 'rb')
    dict = pickle.load(fo, encoding ='bytes')
    X = dict['data']
    fo.close()
    return dict
unpickle('datasets/cifar-10-batches-py/test_batch')

Answered on 2018-08-05 12:22:57
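You can read the CIFAR-10 dataset with the code given below; just make sure you point it at the right directory, the one where the batch files are placed.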
import tensorflow as tf
import pandas as pd
import numpy as np
import math
import timeit
import matplotlib.pyplot as plt
from six.moves import cPickle as pickle
import os
import platform
from subprocess import check_output

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
%matplotlib inline

img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)

def load_pickle(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
        return pickle.load(f)
    elif version[0] == '3':
        return pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))

def load_CIFAR_batch(filename):
    """ load single batch of cifar """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000,3072)
        Y = np.array(Y)
        return X, Y

def load_CIFAR10(ROOT):
    """ load all of cifar """
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte

def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000):
    # Load the raw CIFAR-10 data
    cifar10_dir = '../input/cifar-10-batches-py/'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
    # Subsample the data
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]
    # Convert to float32 and scale pixel values to [0, 1]; the validation
    # split is scaled too so that all three splits share the same range
    x_train = X_train.astype('float32')
    x_val = X_val.astype('float32')
    x_test = X_test.astype('float32')
    x_train /= 255
    x_val /= 255
    x_test /= 255
    return x_train, y_train, x_val, y_val, x_test, y_test

# Invoke the above function to get our data.
x_train, y_train, x_val, y_val, x_test, y_test = get_CIFAR10_data()
print('Train data shape: ', x_train.shape)
print('Train labels shape: ', y_train.shape)
print('Validation data shape: ', x_val.shape)
print('Validation labels shape: ', y_val.shape)
print('Test data shape: ', x_test.shape)
print('Test labels shape: ', y_test.shape)

Answered on 2018-07-30 02:38:43
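I know the reason! I had the same problem and solved it. The key issue is the encoding: with encoding='bytes' the dictionary keys are loaded as bytes (b'data'), so dict['data'] raises KeyError. Change the code from
dict = pickle.load(fo, encoding ='bytes')
to
dict = pickle.load(fo, encoding ='latin1')

Answered on 2018-10-27 15:21:46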
I have run into a similar problem in the past.
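The KeyError from the question can be reproduced without the dataset. The sketch below is only an illustration: `PY2_STYLE_PICKLE` is a hand-written stand-in for a pickle produced by Python 2 (it encodes `{'data': 1}`), not a real CIFAR-10 batch. It shows how the `encoding` argument decides whether the keys come back as bytes or as str.

```python
import pickle

# Hypothetical stand-in: a protocol-0 pickle of {'data': 1}, written by hand
# in the form Python 2 would emit for a dict with a plain str key.
PY2_STYLE_PICKLE = b"(dp0\nS'data'\np1\nI1\ns."

# encoding='bytes' leaves Python 2 str objects as bytes, including dict keys.
as_bytes = pickle.loads(PY2_STYLE_PICKLE, encoding="bytes")

# encoding='latin1' decodes them to str, so ordinary string keys work.
as_latin1 = pickle.loads(PY2_STYLE_PICKLE, encoding="latin1")

print(list(as_bytes))   # the key is a bytes object
print(list(as_latin1))  # the key is a str

# With bytes keys, a str lookup misses and a bytes lookup succeeds.
print("data" in as_bytes, b"data" in as_bytes)
print("data" in as_latin1)
```

So there appear to be two ways out: keep `encoding='bytes'` and index with `dict[b'data']`, or switch to `encoding='latin1'` and keep `dict['data']`.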
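For future readers, I would like to mention that you can find a Python wrapper for automatically downloading, extracting, and parsing the CIFAR-10 dataset here.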
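As a rough idea of what such a helper might do, here is a minimal sketch using only the standard library. It is not the wrapper mentioned above. The names `fetch_and_extract`, `parse_batch`, and `CIFAR10_URL` are made up for illustration, and the archive URL and the `cifar-10-batches-py` folder name are assumptions that should be checked against the CIFAR-10 page.

```python
import os
import pickle
import tarfile
import urllib.request

# Assumed location of the "python version" archive; verify on the dataset page.
CIFAR10_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"


def fetch_and_extract(dest_dir, url=CIFAR10_URL):
    """Download the archive into dest_dir (skipped if already there) and unpack it."""
    os.makedirs(dest_dir, exist_ok=True)
    archive = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(archive):
        urllib.request.urlretrieve(url, archive)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest_dir)
    # Assumed name of the folder inside the archive.
    return os.path.join(dest_dir, "cifar-10-batches-py")


def parse_batch(path):
    """Load one batch file and return its (data, labels) pair."""
    with open(path, "rb") as f:
        # latin1 decodes the Python 2 strings so the keys are ordinary str.
        batch = pickle.load(f, encoding="latin1")
    return batch["data"], batch["labels"]
```

A call might look like `root = fetch_and_extract('datasets')` followed by `X, y = parse_batch(os.path.join(root, 'test_batch'))`. Only extract archives from a source you trust, since `extractall` writes whatever paths the archive contains.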
https://stackoverflow.com/questions/37512290