首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >语义分割的样本权重

语义分割的样本权重
EN

Stack Overflow用户
提问于 2020-05-26 03:54:35
回答 1查看 843关注 0票数 1

我尝试用Keras和Tensorflow2后端来解决语义分割问题。我尝试用分类交叉熵将每个像素标记为22个类别中的一个。我输入和输出的形状是

代码语言:javascript
复制
Input: (None, 224, 224, 3)

Output: (None, 224, 224, 23) 22 and 1 for background

我想为每个样本添加权重,以尝试使用我的模型进行伪标记。对于样本权重,我尝试创建一个样本权重数组,它是一个一维数组,长度与batch size相同。但它失败了,并给出了以下错误:

代码语言:javascript
复制
weights can not be broadcast to values. values.rank=3. weights.rank=1.

然后,我尝试给出3D数组( 16,224,224)作为批次大小为16的样本权重,结果显示以下错误:

代码语言:javascript
复制
Found a sample_weight array with shape (16, 224, 224).
In order to use timestep-wise sample weights, you should specify sample_weight_mode="temporal" in compile(). 
If you just mean to use sample-wise weights, make sure your sample_weight array is 1D.
EN

回答 1

Stack Overflow用户

发布于 2020-09-02 15:36:07

我设法用具有多个输出的TF2和在model.compile中使用loss_weights解决了类似的问题。这样,用于训练网络的最终损失是每个损失的加权和。在我的例子中,我有2个类,而不是23个,所以你需要修改这23个类的代码。还要注意的是,代码中的layer c9有3个过滤器,因此您可能需要增加它。

代码语言:javascript
复制
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam
# build your CNN
inputs = Input((224, 224, 1), name='inputs')
c1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
...
c9 = Conv2D(3, (3, 3), activation='relu',padding='same')(c9)

output_0 = Conv2D(1, (1, 1), name='class0')(c9)
output_1 = Conv2D(1, (1, 1), name='class1')(c9)

model = Model(inputs=inputs, outputs=[output_0, output_1])
model.compile(optimizer=Adam(learning_rate=0.001, name='adam'),
              loss=[BinaryCrossentropy(from_logits=True), BinaryCrossentropy(from_logits=True)],
              loss_weights=[1, 1000],
              metrics=[["accuracy"], ["accuracy", "mse"]])

此外,您可以在数据生成器中定义采样pdf,以便向网络提供更多属于某个类的对象,而不是背景对象。其效果可能类似于使用加权损失函数。最后,网络尝试最小化平均损失值,因此如果您添加一个类别的更多样本,则该平均损失值将受到该类别的结果的精确影响。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62009764

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档