首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何以内存高效的方式将图像数据集读入为numpy数组?

如何以内存高效的方式将图像数据集读入为numpy数组?
EN

Stack Overflow用户
提问于 2019-05-01 09:49:20
回答 1查看 224关注 0票数 0

我试图以numpy数组的形式加载图像的数据集。我如何才能做到这一点,这样我才不会在本地机器上强调RAM的限制,或者创建一个需要太多内存的数组?较大的图像集是训练集,总共约为2 2GB的图像。

这是为了训练残差神经网络,该残差神经网络要求输入数据是数值数组。我使用了模块glob、PIL、skimage、sklearn和numpy来尝试加载图像,但我这样做的方式可能很天真,因为大约2 2GB的图像变成了~17(!)GB数值数组。我已经尝试过寻找解决方案、示例等,但我对Python还不熟悉,所以这个过程非常慢。

用于简单地加载图像的代码是

代码语言:javascript
复制
import glob
from skimage.transform import resize
import numpy as np
from sklearn import datasets
from PIL import Image

def root_2_numpy(data_root):
    """
    Load raw images and output a numpy array of all images and numpy array of labels
    Also preprocesses each image to (224,224) using anti-aliasing
    """
    # load images into numpy array
    all_image_paths = list(data_root.glob('*/*'))  # get image paths
    all_image_paths = [str(path) for path in all_image_paths]  # convert to string
    image_ds = np.zeros([len(all_image_paths), 224, 224,3])  # initialize image dataset
    for i in range(len(all_image_paths)):
        print(i)
        im = Image.open(all_image_paths[i])  # read image as RGB using matplotlib
        if im.mode == 'RGBA' or im.mode == 'L' or im.mode == 'CMYK':
            im = im.convert('RGB')
        elif im.mode =='P':
            im = im.convert('RGBA')
            im = im.convert('RGB')
        im = np.array(im)
        im = resize(im, (224,224), anti_aliasing=True)  # resize image using skimage
        image_ds[i,:,:,:] = im

    # load labels into numpy array
    label_ds = datasets.load_files(data_root, load_content=False, shuffle=False)  # get labels
    n_classes = len(label_ds.target_names)
    Y_ds = np.eye(len(label_ds.target_names))[label_ds.target.reshape(-1)]

    return image_ds, Y_ds, n_classes

我预计这将返回一个大约2 2GB的numpy数组,该数组的维数为(N,W,H,C),用于表示图像的数量、图像宽度、图像高度和3个通道。这不是手头的问题,但我也希望有标签的数据,这些标签是根目录中的类别名称。

除了帮助我高效地加载数据之外,我还非常欣赏我的代码是如何创建这么大的numpy数组的。在我写这篇文章的时候,我有一种感觉,当转换非RBG图像的图像类型时,可能会创建比预期更多的图像。

EN

回答 1

Stack Overflow用户

发布于 2019-05-01 10:51:21

numpy.zeros创建的数组的默认数据类型是64位浮点。所以image_ds = np.zeros([len(all_image_paths), 224, 224,3])创建了一个比你需要的数组大8倍的数组。添加dtype参数,使image_ds的数据类型为uint8 (8位无符号整数):

代码语言:javascript
复制
image_ds = np.zeros([len(all_image_paths), 224, 224,3], dtype=np.uint8)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55930725

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档