首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >从TFRecordDataset获取数据集为numpy数组

从TFRecordDataset获取数据集为numpy数组
EN

Stack Overflow用户
提问于 2018-02-19 17:39:24
回答 1查看 4.9K关注 0票数 5

我正在使用新的tf.data API为CIFAR10数据集创建迭代器。我正在从两个.tfrecord文件中读取数据。一个保存训练数据(train.tfrecords),另一个保存测试数据(test.tfrecords)。一切都很好。然而,在某些时候,我需要两个数据集(训练数据和测试数据)作为numpy数组。

是否可以从tf.data.TFRecordDataset对象中检索作为numpy数组的数据集?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-02-19 19:21:17

您可以使用tf.data.Dataset.batch()转换和tf.contrib.data.get_single_element()来完成这一任务。作为刷新,dataset.batch(n)将处理dataset的连续元素,并通过连接每个组件将它们转换为一个元素。这要求每个组件的所有元素都具有固定的形状。如果n大于dataset中的元素数(或者n没有精确地除以元素数),那么最后一批就可以更小。因此,您可以为n选择一个大值并执行以下操作:

代码语言:javascript
复制
import numpy as np
import tensorflow as tf

# Insert your own code for building `dataset`. For example:
dataset = tf.data.TFRecordDataset(...)  # A dataset of tf.string records.
dataset = dataset.map(...)  # Extract components from each tf.string record.

# Choose a value of `max_elems` that is at least as large as the dataset.
max_elems = np.iinfo(np.int64).max
dataset = dataset.batch(max_elems)

# Extracts the single element of a dataset as one or more `tf.Tensor` objects.
# No iterator needed in this case!
whole_dataset_tensors = tf.contrib.data.get_single_element(dataset)

# Create a session and evaluate `whole_dataset_tensors` to get arrays.
with tf.Session() as sess:
    whole_dataset_arrays = sess.run(whole_dataset_tensors)
票数 6
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48871438

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档