我有一个代表模型的类,其设置如下:
class Model:
def __init__(self):
self.setup_graph()
def setup_graph():
# sets up the model
....
def train(self, dataset):
# dataset is a tf.data.Dataset iterator, from which I can get
# tf.Tensor objects directly, which become part of the graph
....
def predict(self, sample):
# sample is a single NumPy array representing a sample,
# which could be fed to a tf.placeholder using feed_dict
....在培训期间,我希望利用TensorFlow的tf.data.Dataset的效率,但我仍然希望能够在单个样本上获得模型的输出。在我看来,这需要重新创建用于预测的图表。这是真的吗?或者我可以创建一个TF图,在这里我可以使用来自tf.data.Dataset的示例运行,也可以用给定的示例提供给tf.placeholder。
发布于 2018-10-02 14:46:22
您可以像往常一样使用数据集、迭代器等创建模型。然后,如果您想用feed_dict传递一些自定义数据,只需将值传递给get_next()生成的张量即可。
import tensorflow as tf
import numpy as np
dataset = (tf.data.Dataset
.from_tensor_slices(np.ones((100, 3), dtype=np.float32))
.batch(5))
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
output = 2 * batch
with tf.Session() as sess:
print('From iterator:')
print(sess.run(output))
print('From feed_dict:')
print(sess.run(output, feed_dict={batch: [[1, 2, 3]]}))输出:
From iterator:
[[2. 2. 2.]
[2. 2. 2.]
[2. 2. 2.]
[2. 2. 2.]
[2. 2. 2.]]
From feed_dict:
[[2. 4. 6.]]原则上,您可以使用可初始化、可重新初始化或可反馈的迭代器实现相同的效果,但是如果您真的只想测试单个数据样本,我认为这是最快和较少干扰的方法。
https://stackoverflow.com/questions/52610577
复制相似问题