首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >嵌套while_loops的while_loops优化

嵌套while_loops的while_loops优化
EN

Stack Overflow用户
提问于 2017-10-30 16:10:41
回答 1查看 581关注 0票数 1

我正在尝试实现这个方法以用于TensorFlow (取自这里):

代码语言:javascript
复制
def _jacobian_product_sq_euc(X, Y, E, G):
    m = X.shape[0]
    n = Y.shape[0]
    d = X.shape[1]

    for i in range(m):  # 0 - 4
        for j in range(n):
            for k in range(d):
                G[i, k] += E[i,j] * 2 * (X[i, k] - Y[j, k])

我已经用三个tf.while_loops重写了它,但是注意到它非常慢(工作示例这里):

代码语言:javascript
复制
def calc_score():
    gm = tf.zeros([16, 256])

    i = 0
    i_max = 16
    j_max = 16
    d_max = 256

    while_condition_loop1 = lambda i, gm_score: tf.less(i, i_max)
    while_condition_loop2 = lambda i, j, gm_score: tf.less(j, j_max)
    while_condition_loop3 = lambda i, j, d, gm_score: tf.less(d, d_max)
    gm_score = tf.constant(0.)

    def loop3(i, j, d, gm_score):
        gm_score = gm_score + e[i+1, j+1] * 2 * tf.abs((x[i,d] - y[j, d]))
        return [i, j, tf.add(d,1), gm_score]

    def loop2(i, j, gm_score):
        d = 0
        _, _, _, gm_score = tf.while_loop(while_condition_loop3, loop3, [i, j, d, gm_score])
        return [i, tf.add(j,1), gm_score]

    def loop1(i, gm_score):
        j = 0
        _, _, gm_score = tf.while_loop(while_condition_loop2, loop2, [i, j, gm_score])
        return [tf.add(i,1), gm_score]

    _, gm_score = tf.while_loop(while_condition_loop1, loop1, [i, gm_score])

    return gm_score

(注意:我知道在这种情况下,我返回的是单个值,而不是矩阵。但这是另一个问题)

一系列16x256的值大约需要4-5秒来计算。现在我想知道如何优化这个。在这种情况下,除了使用tf.while_loop之外,还有其他选择吗?我的CPU似乎也有相当高的负载,我在训练时收到了很多这样的消息:

2017年-10-30 17:00:51.234993: i tensorflow/core/common_runtime/gpu/pool_allocator.cc:247] PoolAllocator: 257610 get请求后,put_count=385620 evicted_count=128000 eviction_rate=0.331933和未满足的分配rate=0

我对TensorFlow的了解仍然有限,我想知道如何优化这个方法。

我使用python2.7和TensorFlow 1.2.0

EN

回答 1

Stack Overflow用户

发布于 2017-11-03 03:32:00

我不熟悉Jacobian方程,但基本上您应该避免在Tensorflow中遍历元素。您应该在张量/矩阵运算中考虑它,并将tf用于张量运算。我在它们的源代码中找到Jacobian的计算:

checker.py

也许能帮上忙。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47019723

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档