我有一个维度为(BATCH_SIZE*A*B*FEATURE_LENGTH)的输入。现在,我想从每个输入样本的每个A块中选择k行(B行)。每个A块的k值是不同的。例如。
inp = ([[[[ 5, 38, 40, 13, 28],
[12, 6, 36, 20, 23],
[44, 35, 23, 46, 3]],
[[22, 32, 36, 20, 42],
[ 0, 19, 41, 36, 17],
[ 9, 35, 44, 7, 19]],
[[27, 10, 22, 10, 48],
[16, 42, 27, 7, 38],
[35, 32, 15, 39, 28]]]])
#size (1,3,3,5) = (1,A,B,FEATURE_LENGTH)现在假设是k=2,也就是说,我想从3个块中的每一个中提取2行。我想要
row 0 and 1 from 1st block
row 1 and 2 from 2nd block
row 0 and 2 from 3rd block这意味着我希望我的输出看起来像这样
([[[[ 5, 38, 40, 13, 28],
[12, 6, 36, 20, 23]],
[[ 0, 19, 41, 36, 17],
[ 9, 35, 44, 7, 19]],
[[27, 10, 22, 10, 48],
[35, 32, 15, 39, 28]]]])
#op shape = (1,3,2,5)我发现在使用tf.gather_nd时,如果我们提供的索引为
ind = array([[[[0, 0, 0], [0, 0, 1]], [[0, 1, 1], [0, 1, 2]], [[0, 2, 0], [0, 2, 2]]]])但是如果我有大小为(1,16,16,128)和k=4的输入,创建这个长的索引序列将变得单调乏味。在Tensorflow-2中有没有更简单的方法?谢谢!
发布于 2021-03-03 21:12:59
使用带有batch_dims参数的tf.gather():
inds = tf.constant([[[0, 1], [1, 2], [0, 2]]])
output = tf.gather(inp, inds, batch_dims=2)https://stackoverflow.com/questions/66456914
复制相似问题