首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >替换"tf.gather_nd“

替换"tf.gather_nd“
EN

Stack Overflow用户
提问于 2019-06-05 07:58:28
回答 1查看 530关注 0票数 1

我正在做一个项目,但是他们的tensorflow版本不支持tf.gather_nd。我在问,如果可能的话,使用tf.gather,tf.slice或tf.strided_slice重写tf.gather_nd的函数?

tf.gather_nd用于将张量中的切片收集到由索引指定的形状的张量中。详情请访问https://www.tensorflow.org/api_docs/python/tf/gather_nd

谢谢,

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-06-05 17:02:45

此函数应执行相同的工作:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

def my_gather_nd(params, indices):
    idx_shape = tf.shape(indices)
    params_shape = tf.shape(params)
    idx_dims = idx_shape[-1]
    gather_shape = params_shape[idx_dims:]
    params_flat = tf.reshape(params, tf.concat([[-1], gather_shape], axis=0))
    axis_step = tf.cumprod(params_shape[:idx_dims], exclusive=True, reverse=True)
    indices_flat = tf.reduce_sum(indices * axis_step, axis=-1)
    result_flat = tf.gather(params_flat, indices_flat)
    return tf.reshape(result_flat, tf.concat([idx_shape[:-1], gather_shape], axis=0))

# Test
np.random.seed(0)
with tf.Graph().as_default(), tf.Session() as sess:
    params = tf.constant(np.random.rand(10, 20, 30).astype(np.float32))
    indices = tf.constant(np.stack([np.random.randint(10, size=(5, 8)),
                                    np.random.randint(20, size=(5, 8))], axis=-1))
    result1, result2 = sess.run((tf.gather_nd(params, indices),
                                 my_gather_nd(params, indices)))
    print(np.allclose(result1, result2))
    # True
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56452714

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档