我有一个尺寸为batch_size x actions_space x N_quantiles的张量。为了这个例子,假设尺寸是2,3和4。
x_test =
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[-0.71722096, -0.36535808, -0.00286232, 0.37722322],
[ 0.93776643, -1.146626 , 0.1840729 , -1.427474 ],
[ 0.47025302, -0.92792755, -0.1490136 , 1.495174 ]],
[[-1.3838278 , -0.54772085, -0.14298695, 0.39195213],
[-0.7986407 , 0.6419045 , -0.8136323 , 0.9346474 ],
[ 0.96690583, -0.82267016, -0.51641494, 0.6930123 ]]],
dtype=float32)>对于每一批操作,我都有一个操作的索引,并且我希望为该操作减去分位数值。因此,我希望最终有一个大小为Batch_size x N_Quantiles =2x4的数组。
如果我的操作索引是2,0,那么我想以数组结束:
[[ 0.47025302, -0.92792755, -0.1490136 , 1.495174 ],
[-1.3838278 , -0.54772085, -0.14298695, 0.39195213 ]].如何用tf.gather或tf.gather_nd解决这个问题。这应该是非常简单的,但我真的很难提取正确的数组。我试过这样的东西:
tf.gather(x_test, actions, axis=1) 但是没有什么是正确的
发布于 2020-07-30 12:52:04
试试tf.gather(x_test, actions, batch_dims=1)
https://stackoverflow.com/questions/63172891
复制相似问题