首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何加载定制的yolo v-7训练模型

如何加载定制的yolo v-7训练模型
EN

Stack Overflow用户
提问于 2022-09-02 06:39:23
回答 1查看 1.6K关注 0票数 1

如何加载自定义的yolo v-7型号。

我知道如何加载yolo v-5模型:

代码语言:javascript
复制
model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5/runs/train/exp15/weights/last.pt', force_reload=True)

我在网上看过视频,他们建议用这个:

代码语言:javascript
复制
!python detect.py --weights runs/train/yolov7x-custom/weights/best.pt --conf 0.5 --img-size 640 --source final_test_v1.mp4 

但我希望它像一个普通的模型一样加载,并给我在哪里找到物体的包围盒协调。

这就是我在yolo V-5中是怎么做到的:

代码语言:javascript
复制
from models.experimental import attempt_load
yolov5_weight_file = r'weights/rider_helmet_number_medium.pt' # ... may need full path
model = attempt_load(yolov5_weight_file, map_location=device)

def object_detection(frame):
    img = torch.from_numpy(frame)
    img = img.permute(2, 0, 1).float().to(device)  #convert to required shape based on index
    img /= 255.0  
    if img.ndimension() == 3:
        img = img.unsqueeze(0)

    pred = model(img, augment=False)[0]
    pred = non_max_suppression(pred, conf_set, 0.20) # prediction, conf, iou
    # print(pred)
    detection_result = []
    for i, det in enumerate(pred):
        if len(det): 
            for d in det: # d = (x1, y1, x2, y2, conf, cls)
                x1 = int(d[0].item())
                y1 = int(d[1].item())
                x2 = int(d[2].item())
                y2 = int(d[3].item())
                conf = round(d[4].item(), 2)
                c = int(d[5].item())
                
                detected_name = names[c]

                # print(f'Detected: {detected_name} conf: {conf}  bbox: x1:{x1}    y1:{y1}    x2:{x2}    y2:{y2}')
                detection_result.append([x1, y1, x2, y2, conf, c])
                
                frame = cv2.rectangle(frame, (x1, y1), (x2, y2), (255,0,0), 1) # box
                if c!=1: # if it is not head bbox, then write use putText
                    frame = cv2.putText(frame, f'{names[c]} {str(conf)}', (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255), 1, cv2.LINE_AA)

    return (frame, detection_result)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-09-02 07:53:00

您不能使用attempt_load从Yolov5回购,因为这个方法是指向超解析释放文件。您需要使用attempt_load从Yolov7回购,因为这是指向正确的文件。

代码语言:javascript
复制
# yolov7
def attempt_download(file, repo='WongKinYiu/yolov7'):
    # Attempt file download if does not exist
    file = Path(str(file).strip().replace("'", '').lower())
...
代码语言:javascript
复制
# yolov5
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
    # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
    from utils.general import LOGGER

    def github_assets(repository, version='latest'):
...

然后你可以像这样下载它:

代码语言:javascript
复制
# load yolov7 method
from models.experimental import attempt_load

model = attempt_load('yolov7.pt', map_location='cuda:0')  # load FP32 model
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73578690

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档