首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用Tensorflow获取张量中最大值的索引?

如何使用Tensorflow获取张量中最大值的索引?
EN

Stack Overflow用户
提问于 2020-08-16 21:06:38
回答 3查看 953关注 0票数 2

我正在寻找最好的和优化的方法(没有循环)来使用TF2.0TF2.0TF2.0.在Google Colab中获得张量(等级2张量)的最大值的索引。

代码语言:javascript
复制
import tensorflow as tf
A = tf.constant([[0.2,0.8],[0.3,0.9],[0.4,0.7],[0.5,0.4]])
b = tf.math.argmax(A,0)
bb = b.numpy()

这里的最大值的指数是1,1,但问题是我必须给出轴作为输入,即使我改变轴,它也不会给我正确的轴。

EN

回答 3

Stack Overflow用户

发布于 2020-08-16 21:32:24

我不知道这是不是最好的方法,但我发现了这个(来自numpy documentation)可能会对你有帮助:

代码语言:javascript
复制
import numpy as np

c = A.numpy()
np.unravel_index(np.argmax(c, axis=None), c.shape) # outputs (1,1)
票数 1
EN

Stack Overflow用户

发布于 2020-08-16 21:54:25

您可以使用Tensorflow函数来完成此操作,如下所示:

代码语言:javascript
复制
max_val = tf.reduce_max(A, keepdims=True)
cond = tf.equal(A, max_val)
res = tf.where(cond)
res
# <tf.Tensor: shape=(1, 2), dtype=int64, numpy=array([[1, 1]], dtype=int64)>

如果要在结果中使用一维数组,请添加以下内容:

代码语言:javascript
复制
res_1d = tf.squeeze(res)
res_1d
# <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 1], dtype=int64)>

我还没有使用TF1.14,但是我猜你不能为A[res_1d.numpy()[0]]使用.numpy()。但您可以执行以下操作:

代码语言:javascript
复制
tf.slice(A, res_1d, [1, 1])
# <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.9]], dtype=float32)>
票数 1
EN

Stack Overflow用户

发布于 2020-08-16 21:54:25

Tou可以将您的nd数组重塑为1d数组并执行argmax。然后计算nd数组中的真实索引

票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63437264

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档