首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow多维argmax

Tensorflow多维argmax
EN

Stack Overflow用户
提问于 2016-04-04 01:06:07
回答 1查看 2.5K关注 0票数 9

假设我有一个大小为BxWxHxD的张量。我想要处理张量,这样我就有了一个新的BxWxHxD张量,其中只保留每个WxH切片中的最大元素,而所有其他值都为零。

换句话说,我认为实现这一点的最好方法是以某种方式跨WxH切片获取2Dargmax,从而产生行和列的BxD索引张量,然后可以将其转换为单热BxWxHxD张量用作掩码。我该怎么做呢?

EN

回答 1

Stack Overflow用户

发布于 2017-11-03 21:36:49

您可以使用以下函数作为起点。它计算每个批次和每个通道的最大元素的索引。结果数组的格式为(批处理大小,2,通道数)。

代码语言:javascript
复制
def argmax_2d(tensor):

  # input format: BxHxWxD
  assert rank(tensor) == 4

  # flatten the Tensor along the height and width axes
  flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3]))

  # argmax of the flat tensor
  argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32)

  # convert indexes into 2D coordinates
  argmax_x = argmax // tf.shape(tensor)[2]
  argmax_y = argmax % tf.shape(tensor)[2]

  # stack and return 2D coordinates
  return tf.stack((argmax_x, argmax_y), axis=1)

def rank(tensor):

  # return the rank of a Tensor
  return len(tensor.get_shape())
票数 8
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/36388431

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档