首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么np.genfromtxt()最初占用大量内存用于大型数据集?

为什么np.genfromtxt()最初占用大量内存用于大型数据集?
EN

Stack Overflow用户
提问于 2018-03-02 19:53:10
回答 2查看 661关注 0票数 0

我有一个包含450,000列和450行的数据集--所有的数值。我使用NumPy函数将数据集加载到一个np.genfromtxt()数组中:

代码语言:javascript
复制
# The skip_header skips over the column names, which is the first row in the file
train = np.genfromtxt('train_data.csv', delimiter=',', skip_header=1)

train_labels = train[:, -1].astype(int)
train_features = train[:, :-1]

当函数最初加载数据集时,它使用15-20 GB以上的RAM。然而,在该函数完成运行后,它的内存使用量仅为2-3 GB。为什么np.genfromtxt()最初消耗这么多内存?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-03-03 00:18:53

@Kasramvd在评论中提出了一个很好的建议,以研究提议的here解决方案。这个答案中的iter_loadtxt()解决方案最终证明是解决我的问题的完美解决方案:

代码语言:javascript
复制
def iter_loadtxt(filename, delimiter=',', skiprows=0, dtype=float):
    def iter_func():
        with open(filename, 'r') as infile:
            for _ in range(skiprows):
                next(infile)
            for line in infile:
                line = line.rstrip().split(delimiter)
                for item in line:
                    yield dtype(item)
        iter_loadtxt.rowlength = len(line)

    data = np.fromiter(iter_func(), dtype=dtype)
    data = data.reshape((-1, iter_loadtxt.rowlength))
    return data

genfromtxt()占用这么多内存的原因是它在解析数据文件时没有将数据存储在高效的NumPy数组中,因此NumPy在解析我的大数据文件时占用了过多的内存。

票数 0
EN

Stack Overflow用户

发布于 2018-03-02 22:15:15

如果提前知道数组的大小,则可以在解析后将每一行加载到目标数组中,从而节省时间和空间。

例如:

代码语言:javascript
复制
In [173]: txt="""1,2,3,4,5,6,7,8,9,10
     ...: 2,3,4,5,6,7,8,9,10,11
     ...: 3,4,5,6,7,8,9,10,11,12
     ...: """

In [174]: np.genfromtxt(txt.splitlines(),dtype=int,delimiter=',',encoding=None)
Out[174]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])

具有更简单的解析功能:

代码语言:javascript
复制
In [177]: def foo(txt,size):
     ...:     out = np.empty(size, int)
     ...:     for i,line in enumerate(txt):
     ...:        out[i,:] = line.split(',')
     ...:     return out
     ...: 
In [178]: foo(txt.splitlines(),(3,10))
Out[178]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])

out[i,:] = line.split(',')将字符串列表加载到数字dtype数组中,强制进行转换,与np.array(line..., dtype=int)相同。

代码语言:javascript
复制
In [179]: timeit np.genfromtxt(txt.splitlines(),dtype=int,delimiter=',',encoding
     ...: =None)
266 µs ± 427 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [180]: timeit foo(txt.splitlines(),(3,10))
19.2 µs ± 169 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

简单、直接的解析器要快得多。

但是,如果我尝试loadtxtgenfromtxt使用的简化版本:

代码语言:javascript
复制
In [184]: def bar(txt):
     ...:     alist=[]
     ...:     for i,line in enumerate(txt):
     ...:        alist.append(line.split(','))
     ...:     return np.array(alist, dtype=int)
     ...: 
     ...: 
In [185]: bar(txt.splitlines())
Out[185]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [186]: timeit bar(txt.splitlines())
13 µs ± 20.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

对于这个小案子,它甚至更快。genfromtxt必须有大量的解析开销。这是一个小样本,所以内存消耗并不重要。

为了完整性,loadtxt

代码语言:javascript
复制
In [187]: np.loadtxt(txt.splitlines(),dtype=int,delimiter=',')
Out[187]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [188]: timeit np.loadtxt(txt.splitlines(),dtype=int,delimiter=',')
103 µs ± 50.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

fromiter

代码语言:javascript
复制
In [206]: def g(txt):
     ...:     for row in txt:
     ...:         for item in row.split(','):
     ...:             yield item
In [209]: np.fromiter(g(txt.splitlines()),dtype=int).reshape(3,10)
Out[209]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [210]: timeit np.fromiter(g(txt.splitlines()),dtype=int).reshape(3,10)
12.3 µs ± 21.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49076386

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档