首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >是否可以修改YOLOv8,将其用作其他任务的特性提取器?

是否可以修改YOLOv8,将其用作其他任务的特性提取器?
EN

Data Science用户
提问于 2023-03-30 21:56:05
回答 1查看 105关注 0票数 1

我正在阅读YOLOv8 这里的文档,但我找不到一种简单的方法来完成标题中的建议。我想要做的是加载一个经过预先训练的YOLOv8模型,创建一个包含YOLOv8作为子模块的更大的模型,并修改YOLOv8的前向函数,以便我可以访问对象检测丢失和卷积特性,以便它们可以用于为其他自定义任务提供后续层。

为了使事情更清楚,yolov8最初是按照https://docs.ultralytics.com/quickstart/#use-with-python使用的:

代码语言:javascript
复制
from ultralytics import YOLO

# Create a new YOLO model from scratch
model = YOLO('yolov8n.yaml')

# Load a pretrained YOLO model (recommended for training)
model = YOLO('yolov8n.pt')

# Train the model using the 'coco128.yaml' dataset for 3 epochs
results = model.train(data='coco128.yaml', epochs=3)

# Evaluate the model's performance on the validation set
results = model.val()

# Perform object detection on an image using the model
results = model('https://ultralytics.com/images/bus.jpg')

# Export the model to ONNX format
success = model.export(format='onnx')

相反,我想做这样的事情:

代码语言:javascript
复制
import torch
import torch.nn as nn
from ultralytics import YOLO

class Yolov8Wrapper(nn.Module):
    
    def __init__(self, yolov8_feature_dim, n1, n2, n3):
        super().__init__()
        self.yolov8 = YOLO('yolov8n.pt')
        self.fc1 = nn.Linear(yolov8_feature_dim, n1)
        self.fc2 = nn.Linear(yolov8_feature_dim, n2)
        self.fc3 = nn.Linear(yolov8_feature_dim, n3)
    
    def forward(self, images, gt_boxes):
        features, loss = self.yolov8(images, gt_boxes)
        logits1 = self.fc1(features)
        logits2 = self.fc2(features)
        logits3 = self.fc3(features)
        return {
            'logits1': logits1,
            'logits2': logits2,
            'logits3': logits3,
            'yolov8_loss': loss,
        }

上面的代码是一个非常简单的草图,当然不会起作用,但或多或少就是这样。此外,通过创建这个临时包装器,我将无法使用开箱即用的功能来训练、验证和预测YOLO库附带的功能,因为它将是一个自定义体系结构(YOLOv8只是它的一个子模块)。因此,我还需要弄清楚如何编写一个定制的数据中心,以便向YOLOv8提供它所期望的输入以及我的包装器所需的附加内容(包装器中可能有不同的附加层,根据YOLOv8 8的S特性预测不同的输出,将其视为多任务学习)。

这个是可能的吗?这是如何做到的呢?

EN

回答 1

Data Science用户

发布于 2023-03-30 23:20:16

只是对包装器的一些建议:

代码语言:javascript
复制
import torch
import torch.nn as nn
from ultralytics import YOLO

class Yolov8Wrapper(nn.Module):
    
    def __init__(self, yolov8_feature_dim, n1, n2, n3):
        super().__init__()
        self.yolov8 = YOLO('yolov8n.pt')
        self.fc1 = nn.Linear(yolov8_feature_dim, n1)
        self.fc2 = nn.Linear(yolov8_feature_dim, n2)
        self.fc3 = nn.Linear(yolov8_feature_dim, n3)
    
    def forward(self, images, gt_boxes):
        class CustomDataset(torch.utils.data.Dataset):
            def __init__(self, images, gt_boxes):
                self.images = images
                self.gt_boxes = gt_boxes
                
            def __len__(self):
                return len(self.images)
            
            def __getitem__(self, idx):
                image = self.images[idx]
                gt_box = self.gt_boxes[idx]
                return (image, gt_box)
        
        custom_dataloader = torch.utils.data.DataLoader(CustomDataset(images, gt_boxes), batch_size=1)
        
        with torch.no_grad():
            for batch_idx, (image, gt_box) in enumerate(custom_dataloader):
                features, loss = self.yolov8(image, gt_box)
        
        logits1 = self.fc1(features)
        logits2 = self.fc2(features)
        logits3 = self.fc3(features)
        
        return {
            'logits1': logits1,
            'logits2': logits2,
            'logits3': logits3,
            'yolov8_loss': loss,
        }
```
代码语言:javascript
复制
票数 0
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/120596

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档