首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用tf.gather或tf.gather_nd切片

用tf.gather或tf.gather_nd切片
EN

Stack Overflow用户
提问于 2020-07-30 12:10:06
回答 1查看 497关注 0票数 0

我有一个尺寸为batch_size x actions_space x N_quantiles的张量。为了这个例子,假设尺寸是2,3和4。

代码语言:javascript
复制
x_test = 
 <tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[-0.71722096, -0.36535808, -0.00286232,  0.37722322],
        [ 0.93776643, -1.146626  ,  0.1840729 , -1.427474  ],
        [ 0.47025302, -0.92792755, -0.1490136 ,  1.495174  ]],

       [[-1.3838278 , -0.54772085, -0.14298695,  0.39195213],
        [-0.7986407 ,  0.6419045 , -0.8136323 ,  0.9346474 ],
        [ 0.96690583, -0.82267016, -0.51641494,  0.6930123 ]]],
      dtype=float32)>

对于每一批操作,我都有一个操作的索引,并且我希望为该操作减去分位数值。因此,我希望最终有一个大小为Batch_size x N_Quantiles =2x4的数组。

如果我的操作索引是2,0,那么我想以数组结束:

代码语言:javascript
复制
[[ 0.47025302, -0.92792755, -0.1490136 ,  1.495174  ],
[-1.3838278 , -0.54772085, -0.14298695,  0.39195213 ]].

如何用tf.gather或tf.gather_nd解决这个问题。这应该是非常简单的,但我真的很难提取正确的数组。我试过这样的东西:

代码语言:javascript
复制
tf.gather(x_test, actions, axis=1) 

但是没有什么是正确的

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-07-30 12:52:04

试试tf.gather(x_test, actions, batch_dims=1)

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63172891

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档