首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >来自tf.nn.max_pool_with_argmax tensorflow的argmax的用法

来自tf.nn.max_pool_with_argmax tensorflow的argmax的用法
EN

Stack Overflow用户
提问于 2018-01-12 05:20:15
回答 3查看 1.4K关注 0票数 2

我正在尝试使用tf.nn.max_pool_with_argmax()的argmax结果来索引另一个张量。为了简单起见,假设我正在尝试实现以下内容:

代码语言:javascript
复制
output, argmax = tf.nn.max_pool_with_argmax(input, ksize, strides, padding)
tf.assert_equal(input[argmax],output)

现在我的问题是,我如何实现必要的索引操作input[argmax]来实现预期的结果?我猜这涉及到tf.gather_nd()和相关调用的一些用法,但我搞不清楚。如果有必要,我们可以假设输入具有[BatchSize, Height, Width, Channel]维度。

感谢您的帮助!

垫子

EN

回答 3

Stack Overflow用户

发布于 2018-01-13 19:11:03

我找到了一个使用tf.gather_nd的解决方案,它很有效,尽管它看起来不是那么优雅。我使用了发布为here的函数unravel_argmax

代码语言:javascript
复制
def unravel_argmax(argmax, shape):
    output_list = []
    output_list.append(argmax // (shape[2] * shape[3]))
    output_list.append(argmax % (shape[2] * shape[3]) // shape[3])
    return tf.stack(output_list)

def max_pool(input, ksize, strides,padding):
    output, arg_max = tf.nn.max_pool_with_argmax(input=input,ksize=ksize,strides=strides,padding=padding)
    shape = input.get_shape()
    arg_max = tf.cast(arg_max,tf.int32)
    unraveld = unravel_argmax(arg_max,shape)
    indices = tf.transpose(unraveld,(1,2,3,4,0))
    channels = shape[-1]
    bs = tf.shape(iv.m)[0]
    t1 = tf.range(channels,dtype=arg_max.dtype)[None, None, None, :, None]
    t2 = tf.tile(t1,multiples=(bs,) + tuple(indices.get_shape()[1:-2]) + (1,1))
    t3 = tf.concat((indices,t2),axis=-1)
    t4 = tf.range(tf.cast(bs, dtype=arg_max.dtype))
    t5 = tf.tile(t4[:,None,None,None,None],(1,) + tuple(indices.get_shape()[1:-2].as_list()) + (channels,1))
    t6 = tf.concat((t5, t3), -1)    
    return tf.gather_nd(input,t6) 

如果有人有更好的解决方案,我仍然很想知道。

垫子

票数 2
EN

Stack Overflow用户

发布于 2019-08-08 23:21:57

这一小段代码可以工作:

代码语言:javascript
复制
def get_results(data,other_tensor):
    pooled_data, indices = tf.nn.max_pool_with_argmax(data,ksize=[1,ksize,ksize,1],strides=[1,stride,stride,1],padding='VALID',include_batch_in_index=True)
    b,w,h,c = other_tensor.get_shape.as_list()
    other_tensor_pooled = tf.gather(tf.reshape(other_tensor,shape= [b*w*h*c,]),indices)
    return other_tensor_pooled

上面的indices可以用来索引张量。这个函数实际上返回扁平化的索引,要将它与batch_size > 1一起使用,需要将include_batch_in_index作为True传递,以便获得正确的结果。我在这里假设您的批处理大小与data.相同

票数 1
EN

Stack Overflow用户

发布于 2018-12-26 03:14:42

我是这样做的:

代码语言:javascript
复制
def max_pool(input, ksize, strides,padding):
    output, arg_max = tf.nn.max_pool_with_argmax(input=input,ksize=ksize,strides=strides,padding=padding)
    shape=tf.shape(output)
    output1=tf.reshape(tf.gather(tf.reshape(input,[-1]),arg_max),shape)

    err=tf.reduce_sum(tf.square(tf.subtract(output,output1)))
    return output1, err
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48215969

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档