首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >有没有人为keras写过weldon pooling?

有没有人为keras写过weldon pooling?
EN

Stack Overflow用户
提问于 2018-04-24 03:15:59
回答 1查看 262关注 0票数 0

是否在Keras中实现了Weldon pooling 1?

我可以看到它已经由作者2在pytorch中实现了,但是找不到与keras等效的。

1 T. Durand,N. Thome和M. Cord.韦尔登:深度卷积神经网络的弱监督学习。在CVPR,2016年。2

EN

回答 1

Stack Overflow用户

发布于 2018-04-24 05:57:09

这是一个基于lua版本的(有一个pytorch impl,但我认为取max+min的平均值是错误的)。我假设lua版本的最大和最小值的平均值仍然是正确的。我还没有测试整个自定义的层方面,但是足够近了,可以让一些东西运行起来,欢迎评论。

托尼

代码语言:javascript
复制
class WeldonPooling(Layer):
    """Class to implement Weldon selective spacial pooling with negative evidence
    """

    #@interfaces.legacy_global_pooling_support
    def __init__(self, kmax, kmin=-1, data_format=None, **kwargs):
        super(WeldonPooling, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.input_spec = InputSpec(ndim=4)
        self.kmax=kmax
        self.kmin=kmin

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_last':
            return (input_shape[0], input_shape[3])
        else:
            return (input_shape[0], input_shape[1])

    def get_config(self):
        config = {'data_format': self.data_format}
        base_config = super(_GlobalPooling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        if self.data_format == "channels_last":
            inputs = tf.transpose(inputs, [0, 3, 1, 2])
        kmax=self.kmax
        kmin=self.kmin
        shape=tf.shape(inputs)
        batch_size = shape[0]
        num_channels = shape[1]
        h = shape[2]
        w = shape[3]
        n = h * w
        view = tf.reshape(inputs, [batch_size, num_channels, n])
        sorted, indices = tf.nn.top_k(view, n, sorted=True)
        #indices_max = tf.slice(indices,[0,0,0],[batch_size, num_channels, kmax])
        output = tf.div(tf.reduce_sum(tf.slice(sorted,[0,0,0],[batch_size, num_channels, kmax]),2),kmax)

        if kmin > 0:
            #indices_min = tf.slice(indices,[0,0, n-kmin],[batch_size, num_channels, kmin])
            output=tf.add(output,tf.div(tf.reduce_sum(tf.slice(sorted,[0,0,n-kmin],[batch_size, num_channels, kmin]),2),kmin))

        return tf.reshape(output,[batch_size, num_channels])
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49988402

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档