我正在寻找最好的和优化的方法(没有循环)来使用TF2.0TF2.0TF2.0.在Google Colab中获得张量(等级2张量)的最大值的索引。
import tensorflow as tf
A = tf.constant([[0.2,0.8],[0.3,0.9],[0.4,0.7],[0.5,0.4]])
b = tf.math.argmax(A,0)
bb = b.numpy()这里的最大值的指数是1,1,但问题是我必须给出轴作为输入,即使我改变轴,它也不会给我正确的轴。
发布于 2020-08-16 21:32:24
我不知道这是不是最好的方法,但我发现了这个(来自numpy documentation)可能会对你有帮助:
import numpy as np
c = A.numpy()
np.unravel_index(np.argmax(c, axis=None), c.shape) # outputs (1,1)发布于 2020-08-16 21:54:25
您可以使用Tensorflow函数来完成此操作,如下所示:
max_val = tf.reduce_max(A, keepdims=True)
cond = tf.equal(A, max_val)
res = tf.where(cond)
res
# <tf.Tensor: shape=(1, 2), dtype=int64, numpy=array([[1, 1]], dtype=int64)>如果要在结果中使用一维数组,请添加以下内容:
res_1d = tf.squeeze(res)
res_1d
# <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 1], dtype=int64)>我还没有使用TF1.14,但是我猜你不能为A[res_1d.numpy()[0]]使用.numpy()。但您可以执行以下操作:
tf.slice(A, res_1d, [1, 1])
# <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.9]], dtype=float32)>发布于 2020-08-16 21:54:25
Tou可以将您的nd数组重塑为1d数组并执行argmax。然后计算nd数组中的真实索引
https://stackoverflow.com/questions/63437264
复制相似问题