首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将tf.data.Dataset和NumPy数组提供给模型

将tf.data.Dataset和NumPy数组提供给模型
EN

Stack Overflow用户
提问于 2018-10-02 14:32:24
回答 1查看 433关注 0票数 1

我有一个代表模型的类,其设置如下:

代码语言:javascript
复制
class Model:
  def __init__(self):
    self.setup_graph()

  def setup_graph():
    # sets up the model
    ....

  def train(self, dataset):
    # dataset is a tf.data.Dataset iterator, from which I can get 
    # tf.Tensor objects directly, which become part of the graph
    ....

  def predict(self, sample):
    # sample is a single NumPy array representing a sample,
    # which could be fed to a tf.placeholder using feed_dict
    ....

在培训期间,我希望利用TensorFlow的tf.data.Dataset的效率,但我仍然希望能够在单个样本上获得模型的输出。在我看来,这需要重新创建用于预测的图表。这是真的吗?或者我可以创建一个TF图,在这里我可以使用来自tf.data.Dataset的示例运行,也可以用给定的示例提供给tf.placeholder

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-10-02 14:46:22

您可以像往常一样使用数据集、迭代器等创建模型。然后,如果您想用feed_dict传递一些自定义数据,只需将值传递给get_next()生成的张量即可。

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

dataset = (tf.data.Dataset
    .from_tensor_slices(np.ones((100, 3), dtype=np.float32))
    .batch(5))
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

output = 2 * batch

with tf.Session() as sess:
    print('From iterator:')
    print(sess.run(output))
    print('From feed_dict:')
    print(sess.run(output, feed_dict={batch: [[1, 2, 3]]}))

输出:

代码语言:javascript
复制
From iterator:
[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
From feed_dict:
[[2. 4. 6.]]

原则上,您可以使用可初始化、可重新初始化或可反馈的迭代器实现相同的效果,但是如果您真的只想测试单个数据样本,我认为这是最快和较少干扰的方法。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52610577

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档