首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >keras/tensorflow中多类加权损失的语义图像分割

keras/tensorflow中多类加权损失的语义图像分割
EN

Stack Overflow用户
提问于 2019-12-29 23:40:17
回答 3查看 5.3K关注 0票数 3

给定批处理RGB图像作为输入,shape=(batch_size,width,height,3)

和一个表示为one-hot,shape=(batch_size,width,height,n_classes)的多类目标

以及在最后一层具有softmax激活的模型(Unet,DeepLab)。

我在kera/tensorflow中寻找加权分类交叉熵损失函数。

fit_generator中的class_weight参数似乎不起作用,我在这里或在https://github.com/keras-team/keras/issues/2115中都找不到答案。

代码语言:javascript
复制
def weighted_categorical_crossentropy(weights):
    # weights = [0.9,0.05,0.04,0.01]
    def wcce(y_true, y_pred):
        # y_true, y_pred shape is (batch_size, width, height, n_classes)
        loos = ?...
        return loss

    return wcce
EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2019-12-30 04:57:45

我将回答我的问题:

代码语言:javascript
复制
def weighted_categorical_crossentropy(weights):
    # weights = [0.9,0.05,0.04,0.01]
    def wcce(y_true, y_pred):
        Kweights = K.constant(weights)
        if not K.is_tensor(y_pred): y_pred = K.constant(y_pred)
        y_true = K.cast(y_true, y_pred.dtype)
        return K.categorical_crossentropy(y_true, y_pred) * K.sum(y_true * Kweights, axis=-1)
    return wcce

用法:

代码语言:javascript
复制
loss = weighted_categorical_crossentropy(weights)
optimizer = keras.optimizers.Adam(lr=0.01)
model.compile(optimizer=optimizer, loss=loss)
票数 4
EN

Stack Overflow用户

发布于 2020-05-01 03:45:23

我使用的是广义骰子损失。在我的例子中,它比加权分类Crossentropy更好。我的实现是用PyTorch实现的,但是翻译起来应该相当容易。

代码语言:javascript
复制
class GeneralizedDiceLoss(nn.Module):
    def __init__(self):
        super(GeneralizedDiceLoss, self).__init__()

    def forward(self, inp, targ):
        inp = inp.contiguous().permute(0, 2, 3, 1)
        targ = targ.contiguous().permute(0, 2, 3, 1)

        w = torch.zeros((targ.shape[-1],))
        w = 1. / (torch.sum(targ, (0, 1, 2))**2 + 1e-9)

        numerator = targ * inp
        numerator = w * torch.sum(numerator, (0, 1, 2))
        numerator = torch.sum(numerator)

        denominator = targ + inp
        denominator = w * torch.sum(denominator, (0, 1, 2))
        denominator = torch.sum(denominator)

        dice = 2. * (numerator + 1e-9) / (denominator + 1e-9)

        return 1. - dice
票数 1
EN

Stack Overflow用户

发布于 2019-12-30 01:22:26

这个问题可能类似于:Unbalanced data and weighted cross entropy,它有一个公认的答案。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59520807

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档