首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >包含多个头的TorchScript模型

包含多个头的TorchScript模型
EN

Stack Overflow用户
提问于 2022-08-09 10:14:11
回答 1查看 190关注 0票数 2

我的目标是在一个没有定义神经网络的原始类的环境中序列化一个经过训练的模型,一个负载。为了实现这一点,我决定使用TorchScript,因为这似乎是唯一可行的方法。

我有一个多任务模型(nn.Module类型),它使用每个任务的公共体(也是nn.Module,只有几个线性层)和一组线性头模型,每个任务一个。我将头模型存储在一个名为Dict[int, nn.Module]的字典_task_head_models中,并在我的模块类中创建了一个ad转发方法,以便在预测时选择正确的头:

代码语言:javascript
复制
    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
        if task_id not in self._task_head_models.keys():
            raise ValueError(
                f"The task id {task_id} is not valid. Valid task ids are {self._task_head_models.keys()}."
            )

        return self._task_head_models[task_id](self._model(x))

这很好,直到我没有尝试使用torchscript序列化它。当我尝试torch.jit.script(mymodule)时,我得到:

代码语言:javascript
复制
Module 'MyModule' has no attribute '_task_head_models' (This attribute exists on the Python module, but we failed to convert Python type: 'dict' to a TorchScript type. Cannot infer concrete type of torch.nn.Module. Its type was inferred; try adding a type annotation for the attribute.)

似乎没有的是,我的模块包含一个Dict,而不是错误消息中提到的dict。暂时忘记了,现在还不清楚为什么会发生这种情况。在语言引用:https://docs.w3cub.com/pytorch/jit_language_reference.html中似乎支持字典。

我还尝试使用ModuleDict而不是Dict (将键类型改为str),但这似乎也不起作用:Unable to extract string literal index. ModuleDict indexing is only supported with string literals. Enumeration of ModuleDict is supported, e.g. 'for k, v in self.items(): ...':

EN

回答 1

Stack Overflow用户

发布于 2022-09-13 16:42:54

如果Dict _task_head_models中没有很多项,我认为使用if-else分支可以帮助您。示例代码如下:

代码语言:javascript
复制
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._task_head0 = torch.nn.Linear(3, 24)
        self._task_head1 = torch.nn.Linear(3, 24)

    def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
      if task_id == 0:
          return self._task_head0(x)
      elif task_id == 1:
          return self._task_head1(x)
      else:
          raise ValueError(
                f"The task id {task_id} is not valid. Valid task ids are 0,1."
            )
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73290107

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档