我的数据集的格式如下:
训练数据
一个大小为numpy的数组(7855,448,448,3),其中(448,448,3)是RGB图像的numpy版本。因为网络的目的是回归,我还没有找到使用ImageDataGenerator的解决方案。因此,我已经将整个图像数据集转换为一个numpy数组。
训练目标
训练目标是一个尺寸为7855的一维numpy阵列.这些条目对应于培训数据的条目。
要获得numpy数组,我必须将整个数据集加载到一个变量中,然后传递给它以进行拟合和预测。仅这一项就占用了5-6个内存。
当拟合模型时,RAM迅速溢出,运行时崩溃。如何将numpy数组元素分批输入,或者是否有另一种方式可以使用以下格式加载数据集:
|list of images |
|labelled |
|1, 2, 3... |
|n |
|csv file with: |
|1 target1 |
|2 target2 |
|3 target3... |代码https://colab.research.google.com/drive/1FUvPcpYiDtli6vwIaTwacL48RwZ0sq-9
我一直在使用Google,因为这是一个学术研究项目,还没有在高端服务器上投资。
发布于 2019-03-19 19:02:59
您需要使用Dataset API。创建numpy数组时,train_images、train_target使用tf.data.Dataset.from_tensor_slices
dataset = tf.data.Dataset.from_tensor_slices((train_images, train_target))这将创建dataset对象,该对象可以输入到model.fit中,您可以对其进行洗牌、批处理和将任何解析函数映射到此数据集。您可以控制有多少示例将被预装在混洗缓冲区中。重复控制划时代计数,最好留在None,所以它将无限期地重复。
dataset = dataset.shuffle().repeat()
dataset = dataset.batch()请记住,批处理在这个管道中进行,所以您不需要在model.fit中使用批处理,而是需要传递每个时期的批处理数和步骤。后者可能有点棘手,因为您不能像len(dataset)那样做一些事情,所以应该提前计算。
model.fit(dataset, epochs, steps_per_epoch)如果您将遇到graphdef限制错误,最好保存几个较小的numpy数组并将它们作为一个列表传递。
让自己熟悉这个令人毛骨悚然的https://www.tensorflow.org/guide/datasets希望,这是有帮助的。
https://stackoverflow.com/questions/55247444
复制相似问题