I have a dataset with 450,000 columns and 450 rows, all numeric. I load the dataset into a NumPy array using np.genfromtxt():

# skip_header skips over the column names, which are the first row in the file
train = np.genfromtxt('train_data.csv', delimiter=',', skip_header=1)
train_labels = train[:, -1].astype(int)
train_features = train[:, :-1]

While the function is initially loading the dataset, it uses upwards of 15-20 GB of RAM. Yet after the function finishes running, memory usage drops to only 2-3 GB. Why does np.genfromtxt() consume so much memory up front?
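For scale, a quick back-of-the-envelope calculation (my addition, assuming the default float64 dtype and the dimensions stated above) shows the final array alone is about 1.5 GiB, which is consistent with the 2-3 GB resident after loading; the excess during parsing comes from intermediate objects:

```python
# Back-of-the-envelope: size of the final float64 array.
# Dimensions (450 rows x 450,000 columns) are from the question above.
rows, cols = 450, 450_000
bytes_per_value = 8  # float64
gib = rows * cols * bytes_per_value / 2**30
print(round(gib, 2))  # 1.51
```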
Posted on 2018-03-03 00:18:53
@Kasramvd made a good suggestion in the comments to look into the solution proposed here. The iter_loadtxt() solution from that answer turned out to be the perfect fix for my problem:
def iter_loadtxt(filename, delimiter=',', skiprows=0, dtype=float):
    def iter_func():
        with open(filename, 'r') as infile:
            for _ in range(skiprows):
                next(infile)
            for line in infile:
                line = line.rstrip().split(delimiter)
                for item in line:
                    yield dtype(item)
        iter_loadtxt.rowlength = len(line)

    data = np.fromiter(iter_func(), dtype=dtype)
    data = data.reshape((-1, iter_loadtxt.rowlength))
    return data

The reason genfromtxt() takes so much memory is that it does not store the data in an efficient NumPy array while parsing the data file, so NumPy ends up using far more memory than the final array needs when parsing my large file.
Posted on 2018-03-02 22:15:15
If you know the size of the array in advance, you can save time and space by parsing each line and loading it into a pre-allocated target array.

For example:
In [173]: txt="""1,2,3,4,5,6,7,8,9,10
...: 2,3,4,5,6,7,8,9,10,11
...: 3,4,5,6,7,8,9,10,11,12
...: """
In [174]: np.genfromtxt(txt.splitlines(),dtype=int,delimiter=',',encoding=None)
Out[174]:
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
[ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])具有更简单的解析功能:
In [177]: def foo(txt,size):
...: out = np.empty(size, int)
...: for i,line in enumerate(txt):
...: out[i,:] = line.split(',')
...: return out
...:
In [178]: foo(txt.splitlines(),(3,10))
Out[178]:
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])

out[i,:] = line.split(',') loads a list of strings into the numeric-dtype array, forcing the conversion, just as np.array(line..., dtype=int) would.
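To illustrate that casting on assignment (a minimal sketch of my own, separate from the answer's code):

```python
import numpy as np

# Assigning a list of digit strings into an int array forces a cast,
# equivalent to going through np.array(..., dtype=int).
out = np.empty((1, 3), dtype=int)
out[0, :] = '4,5,6'.split(',')
print(out.tolist())  # [[4, 5, 6]]
```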
In [179]: timeit np.genfromtxt(txt.splitlines(),dtype=int,delimiter=',',encoding
...: =None)
266 µs ± 427 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [180]: timeit foo(txt.splitlines(),(3,10))
19.2 µs ± 169 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

The simple, direct parser is much faster.
But what if I try a simplified version of what loadtxt and genfromtxt do:
In [184]: def bar(txt):
...: alist=[]
...: for i,line in enumerate(txt):
...: alist.append(line.split(','))
...: return np.array(alist, dtype=int)
...:
...:
In [185]: bar(txt.splitlines())
Out[185]:
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [186]: timeit bar(txt.splitlines())
13 µs ± 20.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

For this small case it is even faster. genfromtxt must carry a lot of parsing overhead. This is a tiny sample, so memory consumption doesn't matter here.
For completeness, loadtxt:
In [187]: np.loadtxt(txt.splitlines(),dtype=int,delimiter=',')
Out[187]:
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [188]: timeit np.loadtxt(txt.splitlines(),dtype=int,delimiter=',')
103 µs ± 50.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

And with fromiter:
In [206]: def g(txt):
...: for row in txt:
...: for item in row.split(','):
...: yield item
In [209]: np.fromiter(g(txt.splitlines()),dtype=int).reshape(3,10)
Out[209]:
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [210]: timeit np.fromiter(g(txt.splitlines()),dtype=int).reshape(3,10)
12.3 µs ± 21.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

https://stackoverflow.com/questions/49076386
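One further tweak worth knowing (my addition, not from the answers above): np.fromiter takes a count argument, and when the total number of items is known up front, passing it lets NumPy allocate the result in one step instead of growing the buffer as the generator is consumed:

```python
import numpy as np

def g(txt):
    # Same field-by-field generator as in the answer above
    for row in txt:
        for item in row.split(','):
            yield item

txt = """1,2,3,4,5,6,7,8,9,10
2,3,4,5,6,7,8,9,10,11
3,4,5,6,7,8,9,10,11,12"""

# count=3*10 preallocates the output array to exactly 30 elements
arr = np.fromiter(g(txt.splitlines()), dtype=int, count=3 * 10).reshape(3, 10)
print(arr.shape)  # (3, 10)
```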