首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用register_coco_instances加载的数据集进行检测2的增强

如何使用register_coco_instances加载的数据集进行检测2的增强
EN

Stack Overflow用户
提问于 2022-04-06 23:39:09
回答 1查看 2.1K关注 0票数 2

我已经训练了一个关于自定义数据的detectron2模型,我用coco格式标记和导出数据,但是现在我想应用增强和使用增强的数据进行训练。如果我使用的不是自定义的DataLoader,而是register_coco_instances函数,我如何做到这一点。

代码语言:javascript
复制
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)
outputs = predictor(im)

train_annotations_path = "./data/cvat-corn-train-coco-1.0/annotations/instances_default.json"
train_images_path = "./data/cvat-corn-train-coco-1.0/images"
validation_annotations_path = "./data/cvat-corn-validation-coco-1.0/annotations/instances_default.json"
validation_images_path = "./data/cvat-corn-validation-coco-1.0/images"

register_coco_instances(
    "train-corn",
    {},
    train_annotations_path,
    train_images_path
)
register_coco_instances(
    "validation-corn",
    {},
    validation_annotations_path,
    validation_images_path
)
metadata_train = MetadataCatalog.get("train-corn")
dataset_dicts = DatasetCatalog.get("train-corn")

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("train-corn",)
cfg.DATASETS.TEST = ("validation-corn",)
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # Let training initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 10000
cfg.SOLVER.STEPS = []
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 4
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

我在文档中看到,您可以加载数据集并应用如下增强:

代码语言:javascript
复制
dataloader = build_detection_train_loader(cfg,
   mapper=DatasetMapper(cfg, is_train=True, augmentations=[
      T.Resize((800, 800))
   ]))

但是我没有使用自定义数据处理程序,那么最好的方法是什么呢?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-04-08 15:18:38

根据我的经验,如何注册您的数据集(即告诉Detectron2如何获得一个名为"my_dataset"的数据集)与培训期间使用的数据中心无关(即如何从注册的数据集加载信息并将其处理为模型所需的格式)。

因此,您可以任意注册数据集--要么使用register_coco_instances函数,要么直接使用dataset API (DatasetCatalogMetadataCatalog);这无关紧要。重要的是要在数据加载部分应用一些转换。

基本上,您希望自定义数据加载部分,该部分只能通过使用自定义数据加载器来实现(除非您执行脱机增强(这可能不是您想要的)。

现在,您不需要直接在顶层代码中定义和使用自定义数据器。您可以从DefaultTrainer中创建您自己的培训器,并重写它的build_train_loader方法。这很简单,如下所示。

代码语言:javascript
复制
class MyTrainer(DefaultTrainer):

    @classmethod
    def build_train_loader(cls, cfg):
        mapper = DatasetMapper(cfg, is_train=True, augmentations=[T.Resize((800, 800))])
        return build_detection_train_loader(cfg, mapper=mapper)

那么,在您的顶层代码中,唯一需要的更改是使用MyTrainer而不是DefaultTrainer

代码语言:javascript
复制
trainer = MyTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71774744

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档