首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >梯度/损失计算中的解耦脱队列操作

梯度/损失计算中的解耦脱队列操作
EN

Stack Overflow用户
提问于 2016-12-06 17:25:12
回答 1查看 395关注 0票数 1

我目前正在尝试放弃使用提要,开始使用队列,以支持更大的数据集。对于tensorflow中的优化器来说,使用队列很好,因为它们只对每个去队列操作计算一次梯度。但是,我已经实现了与执行行搜索的其他优化器的接口,我不仅需要评估梯度,还需要评估同一批的多个点的损失。不幸的是,对于正常的排队系统,每个损失评估都将执行一个去队列,而不是对同一批进行多次计算。

是否有一种方法使脱队列操作与梯度/损失计算脱钩,使我可以执行一次脱队列,然后在当前批处理上执行多次梯度/损失计算?

编辑:请注意,我的输入张量的大小在批次之间是可变的。我们使用分子数据,每个分子都有不同数量的原子。这与图像数据有很大不同,在图像数据中,所有的东西通常都被缩放到具有相同的维度。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-12-06 19:03:34

通过创建一个变量来将其解耦,存储退出队列的值,然后依赖于这个变量而不是dequeue。在assign期间会提前队列

解决方案#1:固定大小数据,使用变量

代码语言:javascript
复制
(image_batch_live,) = tf.train.batch([image],batch_size=5,num_threads=1,capacity=614)

image_batch = tf.Variable(
  tf.zeros((batch_size, image_size, image_size, color_channels)),
  trainable=False,
  name="input_values_cached")

advance_batch = tf.assign(image_batch, image_batch_live)

现在,image_batch给出队列的最新值,而不对队列进行升级,而advance_batch则对队列进行升级。

解决方案#2:可变大小数据,使用持久张量

在这里,我们通过引入dequeue_opdequeue_op2来解耦工作流。所有的计算都依赖于dequeue_op2dequeue_op的保存值是由它提供的。使用get_session_tensor/get_session_handle可以确保实际数据保留在TensorFlow运行时中,并且通过feed_dict传递的值是一个短字符串标识符。由于dummy_handle的原因,API有点尴尬,我提出了这个问题,这里

代码语言:javascript
复制
import tensorflow as tf
def create_session():
    sess = tf.InteractiveSession(config=tf.ConfigProto(operation_timeout_in_ms=3000))
    return sess

tf.reset_default_graph()

sess = create_session()
dt = tf.int32
dummy_handle = sess.run(tf.get_session_handle(tf.constant(1)))
q = tf.FIFOQueue(capacity=20, dtypes=[dt])
enqueue_placeholder = tf.placeholder(dt, shape=[None])
enqueue_op = q.enqueue(enqueue_placeholder)
dequeue_op = q.dequeue()
size_op = q.size()

dequeue_handle_op = tf.get_session_handle(dequeue_op)
dequeue_placeholder, dequeue_op2 = tf.get_session_tensor(dummy_handle, dt)
compute_op1 = tf.reduce_sum(dequeue_op2)
compute_op2 = tf.reduce_sum(dequeue_op2)+1


# fill queue with variable size data
for i in range(10):
    sess.run(enqueue_op, feed_dict={enqueue_placeholder:[1]*(i+1)})
sess.run(q.close())

try:
    while(True):
        dequeue_handle = sess.run(dequeue_handle_op) # advance the queue
        val1 = sess.run(compute_op1, feed_dict={dequeue_placeholder: dequeue_handle.handle})
        val2 = sess.run(compute_op2, feed_dict={dequeue_placeholder: dequeue_handle.handle})
        size = sess.run(size_op)
        print("val1 %d, val2 %d, queue size %d" % (val1, val2, size))
except tf.errors.OutOfRangeError:
    print("Done")

运行它时,您应该看到如下所示。

代码语言:javascript
复制
val1 1, val2 2, queue size 9
val1 2, val2 3, queue size 8
val1 3, val2 4, queue size 7
val1 4, val2 5, queue size 6
val1 5, val2 6, queue size 5
val1 6, val2 7, queue size 4
val1 7, val2 8, queue size 3
val1 8, val2 9, queue size 2
val1 9, val2 10, queue size 1
val1 10, val2 11, queue size 0
Done
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/41001298

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档