首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow:“for”周期的内存分配

tensorflow:“for”周期的内存分配
EN

Stack Overflow用户
提问于 2016-05-10 20:50:56
回答 1查看 418关注 0票数 0

我尝试使用TensorFlow来计算矩阵中每一列与所有其他列(不包括自身)之间的最小欧几里德距离:

代码语言:javascript
复制
with graph.as_default():
  ...
  def get_diversity(matrix):
       num_rows = matrix.get_shape()[0].value
       num_cols = matrix.get_shape()[1].value
       identity = tf.ones([1, num_cols], dtype=tf.float32)
       diversity = 0

       for i in range(num_cols):
           col = tf.reshape(matrix[:, i], [num_rows, 1])
           col_extended_to_matrix = tf.matmul(neuron_matrix, identity)
           difference_matrix = (col_extended_to_matrix - matrix) ** 2
           sum_vector = tf.reduce_sum(difference_matrix, 0)
           mask = tf.greater(sum_vector, 0)
           non_zero_vector = tf.select(mask, sum_vector, tf.ones([num_cols], dtype=tf.float32) * 9e99)
           min_diversity = tf.reduce_min(non_zero_vector)
           diversity += min_diversity

       return diversity / num_cols
  ...

  diversity = get_diversity(matrix1)

  ...

当我每1000次迭代调用一次get_diversity() (在300k的规模上)时,它工作得很好。但是,当我尝试在每次迭代中调用它时,解释器会返回:

代码语言:javascript
复制
W tensorflow/core/common_runtime/bfc_allocator.cc:271] Ran out of memory trying to allocate 2.99MiB.  See logs for memory state.

我认为这是因为TF在每次调用get_diversity()时都会创建一组新的变量。我试过这个:

代码语言:javascript
复制
def get_diversity(matrix, scope):
    scope.reuse_variables()
...
with tf.variable_scope("diversity") as scope:
    diversity = get_diversity(matrix1, scope)

但它并没有解决这个问题。

如何修复此分配问题并使用具有大量迭代的get_diversity()

EN

回答 1

Stack Overflow用户

发布于 2017-03-16 06:24:12

假设您在训练循环中多次调用get_diversity()Aaron's comment是一个很好的选择:相反,您可以执行以下操作:

代码语言:javascript
复制
diversity_input = tf.placeholder(tf.float32, [None, None], name="diversity_input")
diversity = get_diversity(matrix)

# ...

with tf.Session() as sess:
  for _ in range(NUM_ITERATIONS):
    # ...

    diversity_val = sess.run(diversity, feed_dict={diversity_input: ...})

这将避免每次循环时都创建新的操作,从而防止内存泄漏。此answer有更多详细信息。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/37139113

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档