
场景:语义分割(二分类/多分类)。训练中 mIoU 长期不涨、偶发跳水;可视化发现轮廓“毛边”或标签错位。复盘后发现三件高频问题:(1)图像与掩码各自独立抽样随机增强参数,几何不同步;(2)掩码误用双线性插值,阈值化后产生轮廓毛边;(3)ignore_index/dtype 处理不当,255 被当作普通类别参与插值。
保存为 seg_aug_mismatch_debug.py,分别运行 bug 与修复版本:
python seg_aug_mismatch_debug.py --bug on
python seg_aug_mismatch_debug.py --bug off
# seg_aug_mismatch_debug.py
import argparse
import torch
import torch.nn.functional as F
torch.manual_seed(0)  # fixed seed so the BUG/FIX runs are reproducible and comparable
def make_sample(H=128, W=128):
    """
    Build a toy segmentation sample: a centered white square (class 1) on a
    black background (class 0); the bottom-right quadrant of the mask is
    overwritten with the ignore label 255.

    Returns
    -------
    image : FloatTensor [3, H, W] with values in [0, 1]
    mask  : LongTensor  [1, H, W] with values in {0, 1, 255}
    """
    top, bottom = H // 4, 3 * H // 4
    left, right = W // 4, 3 * W // 4

    img = torch.zeros(3, H, W)
    img[:, top:bottom, left:right] = 1.0

    mask = torch.zeros(1, H, W, dtype=torch.long)
    mask[:, top:bottom, left:right] = 1
    # Bottom-right quadrant is marked ignore (overrides part of the square).
    mask[:, H // 2:, W // 2:] = 255
    return img, mask
def sample_params(H, W, crop=96):
    """Draw one random geometry: horizontal flip, square crop window, output size."""
    flip = torch.rand(()).item() < 0.5
    crop_y = torch.randint(0, H - crop + 1, (1,)).item()
    crop_x = torch.randint(0, W - crop + 1, (1,)).item()
    return {
        "flip": bool(flip),
        "crop_y": int(crop_y),
        "crop_x": int(crop_x),
        "crop_h": crop,
        "crop_w": crop,
        "out_h": H,   # resize back to the original height
        "out_w": W,   # resize back to the original width
    }
def apply_aug_image(x, p):
    """Apply flip -> crop -> bilinear resize to an image tensor [C,H,W] using params p."""
    # 1) optional horizontal flip
    if p["flip"]:
        x = torch.flip(x, dims=[-1])
    # 2) random crop window
    top, left = p["crop_y"], p["crop_x"]
    height, width = p["crop_h"], p["crop_w"]
    x = x[:, top:top + height, left:left + width]
    # 3) resize back to the requested output size (bilinear is fine for images)
    resized = F.interpolate(
        x.unsqueeze(0),
        size=(p["out_h"], p["out_w"]),
        mode="bilinear",
        align_corners=False,
    )
    return resized.squeeze(0)
def apply_aug_mask(mask, p, *, mode="nearest"):
    """
    Apply the same geometric augmentation as the image path to a label mask.

    Parameters
    ----------
    mask : LongTensor [1, H, W] with values in {0, 1, 255} (255 = ignore).
    p    : parameter dict produced by sample_params().
    mode : interpolation mode for the foreground channel. "nearest" keeps
           labels discrete; "bilinear" is deliberately selectable to
           demonstrate the soft-edge artifact.

    Returns
    -------
    LongTensor [1, H, W] with values in {0, 1, 255}.

    Fixes over the original:
    * the ignore channel is ALWAYS resampled with nearest-neighbor, even when
      mode="bilinear" (the original passed `mode` through, contradicting its
      own comment);
    * writing 255 used a [H, W] boolean index on a [1, H, W] tensor, which
      raises an IndexError in PyTorch; the index shapes now match;
    * the result is [1, H, W] as documented (a stray unsqueeze previously
      produced [1, 1, H, W]);
    * the unused background channel and the duplicated threshold branches
      were removed.
    """
    ignore = (mask == 255).float()  # [1,H,W] ignore channel
    fg = (mask == 1).float()        # [1,H,W] foreground channel

    def geo(t, interp):
        # Shared geometry: flip -> crop -> resize back to (out_h, out_w).
        if p["flip"]:
            t = torch.flip(t, dims=[-1])
        y0, x0, ch, cw = p["crop_y"], p["crop_x"], p["crop_h"], p["crop_w"]
        t = t[:, y0:y0 + ch, x0:x0 + cw]
        t = F.interpolate(t.unsqueeze(0), size=(p["out_h"], p["out_w"]), mode=interp)
        return t.squeeze(0)

    # Ignore must stay crisp: always nearest-neighbor, regardless of `mode`.
    ignore_t = geo(ignore, "nearest")
    fg_t = geo(fg, mode)  # with "bilinear" this yields soft values in [0,1]

    # Threshold back to {0,1}; with bilinear input this creates ragged edges.
    out = (fg_t > 0.5).long()
    out[ignore_t > 0.5] = 255
    return out  # [1,H,W] long
def iou_binary(pred, target, ignore=255):
    """Binary IoU over class 1, skipping pixels labeled `ignore` in the target."""
    keep = target != ignore
    pred_fg = (pred == 1) & keep
    true_fg = (target == 1) & keep
    intersection = (pred_fg & true_fg).sum().float()
    # clamp avoids 0/0 when both prediction and target are empty
    union = (pred_fg | true_fg).sum().float().clamp_min(1.0)
    return (intersection / union).item()
def run(bug=True, N=200):
    """
    Simulate N samples and compare mask/image alignment under two regimes.

    bug=True : image and mask draw *independent* augmentation parameters and
               the mask is resized with bilinear interpolation (the failure
               mode this script demonstrates).
    bug=False: image and mask share one parameter set and the mask uses
               nearest-neighbor interpolation (the fix).

    Prints the mean IoU between the augmented mask and the reference mask,
    plus the mean fraction of "soft" (non-binary) mask pixels produced by
    bilinear interpolation.
    """
    H = W = 128
    ious = []
    frac_non_binary = []
    for _ in range(N):
        img, m = make_sample(H, W)
        # Sample the "image augmentation parameters" (the gold standard).
        p_img = sample_params(H, W, crop=96)
        if bug:
            # Bug: the mask samples a *separate* parameter set and is
            # resampled with bilinear interpolation.
            p_mask = sample_params(H, W, crop=96)
            m_bug = apply_aug_mask(m, p_mask, mode="bilinear")
            m_ref = apply_aug_mask(m, p_img, mode="nearest")
            iou = iou_binary(m_bug, m_ref)
            # Check whether the mask picked up non-discrete pixels
            # (introduced by bilinear interpolation, before thresholding).
            # Approximation: if values stay exactly 0/1 we call it clean,
            # otherwise the mask was softened.
            # To observe this, redo the bilinear resample WITHOUT thresholding.
            fg = (m == 1).float()
            def geo_float(t, p):
                if p["flip"]: t = torch.flip(t, [-1])
                t = t[:, p["crop_y"]:p["crop_y"]+p["crop_h"], p["crop_x"]:p["crop_x"]+p["crop_w"]]
                t = F.interpolate(t.unsqueeze(0), size=(H, W), mode="bilinear").squeeze(0)
                return t
            soft = geo_float(fg, p_mask)
            non01 = float(((soft > 1e-6) & (soft < 1 - 1e-6)).float().mean().item())
            frac_non_binary.append(non01)
        else:
            # Fix: mask shares the image's geometry parameters and uses
            # nearest-neighbor interpolation.
            m_fix = apply_aug_mask(m, p_img, mode="nearest")
            m_ref = apply_aug_mask(m, p_img, mode="nearest")
            iou = iou_binary(m_fix, m_ref)
            frac_non_binary.append(0.0)
        ious.append(iou)
    tag = "BUG" if bug else "FIX"
    print(f"[{tag}] mean IoU(image-mask alignment) = {sum(ious)/len(ious):.3f}")
    print(f"[{tag}] mean fraction of soft (non-binary) mask pixels ≈ {sum(frac_non_binary)/len(frac_non_binary):.3f}")
if __name__ == "__main__":
    # CLI switch: `--bug on` reproduces the failure, `--bug off` runs the fix.
    parser = argparse.ArgumentParser()
    parser.add_argument("--bug", choices=["on","off"], default="on")
    args = parser.parse_args()
    run(bug=(args.bug=="on"))

运行两个版本后你会看到:BUG 版的对齐 IoU 明显偏低,且出现大量非离散(软)掩码像素;FIX 版 IoU 为 1.0,软像素比例为 0。
下面给出一个“成对增强器”,保证图像与掩码共享随机参数;掩码用最近邻;自动遮蔽 ignore_index。
import torch
import torch.nn.functional as F
class SegPairTransform:
    """
    Paired augmenter for (image, mask): both share ONE set of random geometry
    parameters; the image is resized bilinearly, the mask with nearest-
    neighbor; ignore_index pixels survive the transform intact.

    Fixes over the original:
    * `def __call.image__` / `def __call.mask__` are invalid Python syntax
      (dots are not allowed in identifiers); renamed to `_image` / `_mask`
      and the call sites in `__call__` updated accordingly;
    * the mask output buffer was allocated with the *input* spatial size
      (`torch.full_like(mask, ...)`), which breaks as soon as crop/out_size
      changes the spatial shape; it is now allocated at (out_h, out_w).
    """

    def __init__(self, crop=512, out_size=None, hflip_p=0.5, ignore_index=255):
        self.crop = crop                  # square crop side, clamped to H/W
        self.out_size = out_size          # (H, W) output, or None = input size
        self.hflip_p = hflip_p            # probability of horizontal flip
        self.ignore_index = ignore_index  # label that must pass through untouched

    def _sample(self, H, W):
        """Draw one shared random geometry for an (image, mask) pair."""
        p = {
            "flip": bool(torch.rand(()) < self.hflip_p),
            "cy": int(torch.randint(0, max(1, H - self.crop + 1), (1,)).item()),
            "cx": int(torch.randint(0, max(1, W - self.crop + 1), (1,)).item()),
            "ch": min(self.crop, H),
            "cw": min(self.crop, W),
        }
        if self.out_size is None:
            p["oh"], p["ow"] = H, W
        else:
            p["oh"], p["ow"] = self.out_size
        return p

    @staticmethod
    def _flip(x, do):
        return torch.flip(x, [-1]) if do else x

    @staticmethod
    def _crop(x, y0, x0, h, w):
        return x[..., y0:y0+h, x0:x0+w]

    def _image(self, img, p):
        """Shared geometry + bilinear resize for the image tensor [C,H,W]."""
        img = self._flip(img, p["flip"])
        img = self._crop(img, p["cy"], p["cx"], p["ch"], p["cw"])
        img = F.interpolate(img.unsqueeze(0), size=(p["oh"], p["ow"]),
                            mode="bilinear", align_corners=False).squeeze(0)
        return img

    def _mask(self, mask, p):
        """Shared geometry + nearest resize for the label mask [1,H,W] (long)."""
        assert mask.dtype == torch.long
        ignore = (mask == self.ignore_index).float()
        # Per-class channels (binary example; for many classes, one-hot the
        # mask and recombine). The output buffer must use the OUTPUT size.
        out = mask.new_zeros((*mask.shape[:-2], p["oh"], p["ow"]))
        for cls in torch.unique(mask):
            c = int(cls.item())
            if c == self.ignore_index:
                continue
            ch = (mask == c).float()
            ch = self._flip(ch, p["flip"])
            ch = self._crop(ch, p["cy"], p["cx"], p["ch"], p["cw"])
            ch = F.interpolate(ch.unsqueeze(0), size=(p["oh"], p["ow"]),
                               mode="nearest").squeeze(0)
            out[ch > 0.5] = c
        # Re-apply ignore on top (nearest keeps its border crisp).
        ig = self._flip(ignore, p["flip"])
        ig = self._crop(ig, p["cy"], p["cx"], p["ch"], p["cw"])
        ig = F.interpolate(ig.unsqueeze(0), size=(p["oh"], p["ow"]),
                           mode="nearest").squeeze(0)
        out[ig > 0.5] = self.ignore_index
        return out

    def __call__(self, img, mask):
        """Transform a pair using one shared random parameter set."""
        H, W = mask.shape[-2:]
        p = self._sample(H, W)
        img_t = self._image(img, p)
        mask_t = self._mask(mask, p)
        return img_t, mask_t

数据集用法:
class SegDataset(torch.utils.data.Dataset):
    """Minimal (image, mask) dataset that applies an optional paired transform."""

    def __init__(self, imgs, masks, transform=None):
        self.imgs, self.masks, self.t = imgs, masks, transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, i):
        img, mask = self.imgs[i], self.masks[i]
        if self.t is not None:
            img, mask = self.t(img, mask)
        # Only the image is cast to float; the mask must stay integer labels.
        return img.float(), mask.long()

ignore_index 在 CE/Dice/IoU 均需屏蔽(loss 与 metric 两侧都加 valid 掩码)。确认 collate_fn 不会对 mask 做 float() 或归一化。autocast 只包裹前向即可;mask 保持 long,不进 autocast。

def assert_mask_integrity(mask, class_ids, ignore_index=255):
# 1) dtype
assert mask.dtype == torch.long, f"mask dtype should be long, got {mask.dtype}"
# 2) 取值集合
vals = torch.unique(mask)
valid = set(int(v) for v in vals.tolist())
allow = set(class_ids) | {ignore_index}
assert valid.issubset(allow), f"mask values {valid} not subset of {allow}"
# 3) 统计soft像素(若非最近邻导致)
# 若你的管线里有soft掩码,直接告警
# 这里略
def assert_pair_aug_sync(params_img, params_mask):
    # Only for augmenters that record their own random parameters: the image
    # and mask parameter dicts must be exactly identical (shared geometry).
    assert params_img == params_mask, "image/mask augment params not synced"

训练时周期性统计每类像素占比,若小目标类像素骤降,优先排查增强与插值。
图像-掩码不同步与错误插值,是语义分割里最隐蔽的“训练杀手”。把“共享随机参数 + 掩码最近邻 + 严格 dtype/ignore 处理”固化为模板,再配上一眼就能看出问题的可视化与护栏,这类问题基本可以一次性清零。上面的脚本能在几分钟内重现与验证修复效果,建议纳入项目的 debug 工具箱。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。