首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在神经网络实现中耗尽内存(使用Numpy Array)。什么能优化数据负载?

在神经网络实现中耗尽内存(使用Numpy Array)。什么能优化数据负载?
EN

Stack Overflow用户
提问于 2019-03-19 18:08:59
回答 1查看 649关注 0票数 0

我的数据集的格式如下:

训练数据

一个大小为numpy的数组(7855,448,448,3),其中(448,448,3)是RGB图像的numpy版本。因为网络的目的是回归,我还没有找到使用ImageDataGenerator的解决方案。因此,我已经将整个图像数据集转换为一个numpy数组。

训练目标

训练目标是一个尺寸为7855的一维numpy阵列.这些条目对应于培训数据的条目。

要获得numpy数组,我必须将整个数据集加载到一个变量中,然后传递给它以进行拟合和预测。仅这一项就占用了5-6个内存。

当拟合模型时,RAM迅速溢出,运行时崩溃。如何将numpy数组元素分批输入,或者是否有另一种方式可以使用以下格式加载数据集:

代码语言:javascript
复制
|list of images |
|labelled       |
|1, 2, 3...     |
|n              |


|csv file with: |
|1   target1    |
|2   target2    |
|3   target3... |

代码https://colab.research.google.com/drive/1FUvPcpYiDtli6vwIaTwacL48RwZ0sq-9

我一直在使用Google,因为这是一个学术研究项目,还没有在高端服务器上投资。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-03-19 19:02:59

您需要使用Dataset API。创建numpy数组时,train_images、train_target使用tf.data.Dataset.from_tensor_slices

代码语言:javascript
复制
dataset = tf.data.Dataset.from_tensor_slices((train_images, train_target))

这将创建dataset对象,该对象可以输入到model.fit中,您可以对其进行洗牌、批处理和将任何解析函数映射到此数据集。您可以控制有多少示例将被预装在混洗缓冲区中。重复控制划时代计数,最好留在None,所以它将无限期地重复。

代码语言:javascript
复制
dataset = dataset.shuffle().repeat()
dataset = dataset.batch()

请记住,批处理在这个管道中进行,所以您不需要在model.fit中使用批处理,而是需要传递每个时期的批处理数和步骤。后者可能有点棘手,因为您不能像len(dataset)那样做一些事情,所以应该提前计算。

代码语言:javascript
复制
model.fit(dataset, epochs, steps_per_epoch)

如果您将遇到graphdef限制错误,最好保存几个较小的numpy数组并将它们作为一个列表传递。

让自己熟悉这个令人毛骨悚然的https://www.tensorflow.org/guide/datasets希望,这是有帮助的。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55247444

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档