首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用tf.gather或tf.gather_nd

使用tf.gather或tf.gather_nd
EN

Stack Overflow用户
提问于 2021-03-03 20:17:39
回答 1查看 57关注 0票数 0

我有一个维度为(BATCH_SIZE*A*B*FEATURE_LENGTH)的输入。现在,我想从每个输入样本的每个A块中选择k行(B行)。每个A块的k值是不同的。例如。

代码语言:javascript
复制
inp = ([[[[ 5, 38, 40, 13, 28],
         [12,  6, 36, 20, 23],
         [44, 35, 23, 46,  3]],

        [[22, 32, 36, 20, 42],
         [ 0, 19, 41, 36, 17],
         [ 9, 35, 44,  7, 19]],

        [[27, 10, 22, 10, 48],
         [16, 42, 27,  7, 38],
         [35, 32, 15, 39, 28]]]])
#size (1,3,3,5) = (1,A,B,FEATURE_LENGTH)

现在假设是k=2,也就是说,我想从3个块中的每一个中提取2行。我想要

代码语言:javascript
复制
row 0 and 1 from 1st block
row 1 and 2 from 2nd block
row 0 and 2 from 3rd block

这意味着我希望我的输出看起来像这样

代码语言:javascript
复制
([[[[ 5, 38, 40, 13, 28],
    [12,  6, 36, 20, 23]],
   
    [[ 0, 19, 41, 36, 17],
     [ 9, 35, 44,  7, 19]],

    [[27, 10, 22, 10, 48],
     [35, 32, 15, 39, 28]]]])
#op shape = (1,3,2,5)

我发现在使用tf.gather_nd时,如果我们提供的索引为

代码语言:javascript
复制
ind = array([[[[0, 0, 0], [0, 0, 1]], [[0, 1, 1], [0, 1, 2]], [[0, 2, 0], [0, 2, 2]]]])

但是如果我有大小为(1,16,16,128)k=4的输入,创建这个长的索引序列将变得单调乏味。在Tensorflow-2中有没有更简单的方法?谢谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-03-03 21:12:59

使用带有batch_dims参数的tf.gather()

代码语言:javascript
复制
inds = tf.constant([[[0, 1], [1, 2], [0, 2]]])
output = tf.gather(inp, inds, batch_dims=2)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66456914

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档