首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >从自己的训练检查点加载用于推理的MMDetection会产生垃圾检测。

从自己的训练检查点加载用于推理的MMDetection会产生垃圾检测。
EN

Stack Overflow用户
提问于 2021-03-08 21:15:10
回答 1查看 2.2K关注 0票数 4

我使用MMDetection协同实验室教程训练了一个非常简单的模型,然后使用以下方法验证结果:

代码语言:javascript
复制
img = mmcv.imread('/content/mmdetection/20210301_145246_123456.jpg')
img = cv2.resize(img, (0,0), fx=0.25, fy=0.25)

model.cfg = cfg
result = inference_detector(model, img)
show_result_pyplot(model, img, result)

证实了它的效果很好。

然后,我遵循与训练相同的步骤,但我加载自己的训练检查点,我不训练。然后,运行上面的验证片段会产生垃圾结果。

这是代码中的

代码语言:javascript
复制
from mmcv import Config
cfg = Config.fromfile('configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')

from mmdet.apis import set_random_seed

# Modify dataset type and path
cfg.dataset_type = 'SamplesDataset'
cfg.data_root = 'samples_dataset/'

cfg.data.test.type = 'SamplesDataset'
cfg.data.test.data_root = 'samples_dataset/'
cfg.data.test.ann_file = 'train.txt'
cfg.data.test.img_prefix = 'o2h'

cfg.data.train.type = 'SamplesDataset'
cfg.data.train.data_root = 'samples_dataset/'
cfg.data.train.ann_file = 'train.txt'
cfg.data.train.img_prefix = 'o2h'

cfg.data.val.type = 'SamplesDataset'
cfg.data.val.data_root = 'samples_dataset/'
cfg.data.val.ann_file = 'val.txt'
cfg.data.val.img_prefix = 'o2h'

# modify num classes of the model in box head
cfg.model.roi_head.bbox_head.num_classes = 1
# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
# cfg.load_from = 'checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'
cfg.load_from = './experiments/epoch_1.pth'

# Set up working dir to save files and logs.
cfg.work_dir = './experiments'

# The original learning rate (LR) is set for 8-GPU training.
# We divide it by 8 since we only use one GPU.
cfg.optimizer.lr = 0.02 / 8
cfg.lr_config.warmup = None
cfg.log_config.interval = 10
cfg.runner = dict(type='EpochBasedRunner', max_epochs=1)
cfg.total_epochs = 1

# Change the evaluation metric since we use customized dataset.
cfg.evaluation.metric = 'mAP'
# We can set the evaluation interval to reduce the evaluation times
# cfg.evaluation.interval = 12
# We can set the checkpoint saving interval to reduce the storage cost
cfg.checkpoint_config.interval = 1

# Set seed thus the results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)


# We can initialize the logger for training and have a look
# at the final config used for training
# print(f'Config:\n{cfg.pretty_text}')

from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector

# Build dataset
# datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_detector(cfg.model)
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir
# mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# train_detector(model, datasets, cfg, distributed=False, validate=True)

显然,我通常不会只为验证我的模型而这么做,但这是我的许多调试步骤之一,因为我的目标是在本地下载和运行该模型。这就是我想在本地做的事情:

代码语言:javascript
复制
import sys
import glob
import time

sys.path.insert(0, '../mmdetection')
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
from mmdet.models import build_detector
import mmcv
import numpy as np

file_paths = glob.glob('samples/o2h/*.jpg')

cfg = mmcv.Config.fromfile('../mmdetection/configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')
cfg.model.roi_head.bbox_head.num_classes = 1
cfg.load_from = 'models/mmdet_faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.pth' # my own checkpoint
model = build_detector(cfg.model)
model.CLASSES = ('hash',)
model.cfg = cfg

file_path = np.random.choice(file_paths)
print(file_path)

start = time.time()
result = inference_detector(model, file_path)
print(f"Time taken for inference: {time.time() - start:.2f}s")
show_result_pyplot(model, file_path, result)
EN

回答 1

Stack Overflow用户

发布于 2022-04-14 16:04:50

代码中的一个错误是没有更新num_classes for mask_head

我们在这里的目标应该是复制用于培训的配置文件,也应该用于测试/验证。如果在配置文件中使用1 num_classes bbox_headmask_head对模型进行了培训,但对于验证/测试,则使用80 num_classes作为默认设置,那么这将导致测试过程中的不匹配,导致垃圾检测和分段。

要取得最佳效果,有两个解决方案:

  1. 在进行推理之前更改配置文件中的num_classes
  2. 培训完成后,将模型和配置文件保存为泡菜。

注意:第一个解决方案是最好的。

  1. 在进行推理之前,更改配置文件中的num_classes。 首先,查找数据集中的类总数。这里,num_classes是培训数据集中的总类数。 定位到此路径:mmdetection/configs/model_name (model_name是用于培训的名称) 在这里,在model_name文件夹中,找到用于培训的..._config.py。在这个配置文件中,如果您找到了model = dict(...),那么更改这些键中每个键的num_classesbbox_head, mask_headbbox_head可能是列表。因此,更改列表中每个键的num_classes。 如果找不到model = dict(...),那么第一行是_base_ = '...'所以,打开该配置文件并检查是否找到了model=dict(...)。如果没有找到,请继续打开_base_的文件位置。 更改num_classes后,使用以下代码进行推断:

更改num_classes后的代码:

代码语言:javascript
复制
from mmdet.apis import init_detector, inference_detector
import mmcv
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
%matplotlib inline

config_file = './configs/scnet/scnet_x101_64x4d_fpn_20e_coco.py' #(I have used SCNet for training)

checkpoint_file = 'tutorial_exps/epoch_40.pth' #(checkpoint saved after training)

model = init_detector(config_file, checkpoint_file, device='cuda:0') #loading the model

img = 'test.png'

result = inference_detector(model, img)

#visualize the results in a new window
im1 = cv2.imread(img)[:,:,::-1]
#im_ones = np.ones(im1.shape, dtype='uint')*255
# model.show_result(im_ones, result, out_file='fine_result6.jpg')
plt.imshow(model.show_result(im1, result))

  1. 一旦培训完成,将模型和配置保存为泡菜。 另一种解决方案是在培训完成后立即将模型和配置保存为泡菜,而不管是否依赖于on检测来完成。当您加载默认配置文件(没有任何更改)时,它不会产生所需的结果。您的配置应该与用于培训的配置完全一致。因此,最好将模型和配置保存为泡菜,而不是加载它们。 注意: 在完成培训后,应立即保存泡菜文件。

保存为泡菜的代码:

代码语言:javascript
复制
import pickle

with open('mdl.pkl','wb') as f:
    pickle.dump(model, f)

with open('cfg.pkl','wb') as f:
    pickle.dump(cfg, f)

您可以随时随地使用此模型。对于使用保存的模型进行推理,请使用以下命令:

代码语言:javascript
复制
import pickle, mmcv
from mmdet.apis import inference_detector, show_result_pyplot

model = pickle.load(open('mdl.pkl','rb'))
cfg = pickle.load(open('cfg.pkl','rb'))

img = mmcv.imread('images/test.png')

model.cfg = cfg
result = inference_detector(model, img)
show_result_pyplot(model, img, result)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66537288

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档