我的目标是在一个没有定义神经网络的原始类的环境中序列化一个经过训练的模型,一个负载。为了实现这一点,我决定使用TorchScript,因为这似乎是唯一可行的方法。
我有一个多任务模型(nn.Module类型),它使用每个任务的公共体(也是nn.Module,只有几个线性层)和一组线性头模型,每个任务一个。我将头模型存储在一个名为Dict[int, nn.Module]的字典_task_head_models中,并在我的模块类中创建了一个ad转发方法,以便在预测时选择正确的头:
def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
if task_id not in self._task_head_models.keys():
raise ValueError(
f"The task id {task_id} is not valid. Valid task ids are {self._task_head_models.keys()}."
)
return self._task_head_models[task_id](self._model(x))这很好,直到我没有尝试使用torchscript序列化它。当我尝试torch.jit.script(mymodule)时,我得到:
Module 'MyModule' has no attribute '_task_head_models' (This attribute exists on the Python module, but we failed to convert Python type: 'dict' to a TorchScript type. Cannot infer concrete type of torch.nn.Module. Its type was inferred; try adding a type annotation for the attribute.)似乎没有的是,我的模块包含一个Dict,而不是错误消息中提到的dict。暂时忘记了,现在还不清楚为什么会发生这种情况。在语言引用:https://docs.w3cub.com/pytorch/jit_language_reference.html中似乎支持字典。
我还尝试使用ModuleDict而不是Dict (将键类型改为str),但这似乎也不起作用:Unable to extract string literal index. ModuleDict indexing is only supported with string literals. Enumeration of ModuleDict is supported, e.g. 'for k, v in self.items(): ...':
发布于 2022-09-13 16:42:54
如果Dict _task_head_models中没有很多项,我认为使用if-else分支可以帮助您。示例代码如下:
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self._task_head0 = torch.nn.Linear(3, 24)
self._task_head1 = torch.nn.Linear(3, 24)
def forward(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor:
if task_id == 0:
return self._task_head0(x)
elif task_id == 1:
return self._task_head1(x)
else:
raise ValueError(
f"The task id {task_id} is not valid. Valid task ids are 0,1."
)https://stackoverflow.com/questions/73290107
复制相似问题