首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >基于detectron2的语义分割

基于detectron2的语义分割
EN

Stack Overflow用户
提问于 2022-03-08 14:39:41
回答 1查看 1.2K关注 0票数 1

我使用Detectron2来训练一个使用实例分割的自定义模型,并且工作得很好。在google上有几个使用实例分割的Detectron2教程,但是没有关于语义分割的内容。因此,为了训练自定义实例分割,基于colab (https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5#scrollTo=7unkuuiqLdqd)的代码是:

代码语言:javascript
复制
from detectron2.engine import DefaultTrainer

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("balloon_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # Let training initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = 300    # 300 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset
cfg.SOLVER.STEPS = []        # do not decay learning rate
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only has one class (ballon). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)
# NOTE: this config means the number of classes, but a few popular unofficial tutorials incorrect uses num_classes+1 here.

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

为了运行语义分割训练,我将"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"替换为"/Misc/semantic_R_50_FPN_1x.yaml",基本改变了预训练模型,仅此而已。我得到了一个错误:

代码语言:javascript
复制
TypeError: cross_entropy_loss(): argument 'target' (position 2) must be Tensor, not NoneType

我是如何在Google上建立语义分割的?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-05-01 13:23:17

为了训练语义分割,您可以使用相同的COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml模型。你没必要改变这句话。

您在问题中显示的培训代码是正确的,也可以用于语义分割。所有的改变都是标签文件。

一旦对模型进行了训练,您就可以通过从经过训练的模型加载模型权重来使用它进行推理。

代码语言:javascript
复制
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5   # set the testing threshold for this model
cfg.DATASETS.TEST = ("Detectron_terfspot_" + "test", )                      # the name given to your dataset when loading/registering it
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
predictor = DefaultPredictor(cfg)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71396788

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档