I'm trying to export a PyTorch model to TorchScript via scripting, and I'm stuck. I've created a toy class to illustrate the problem:
```python
import torch
from torch import nn


class SadModule(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.
    """

    def __init__(self, use_skip: bool):
        nn.Module.__init__(self)
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x
```

It is basically just a linear layer with an optional skip connection. If I try to script it with

```python
mod1 = SadModule(False)
scripted_mod1 = torch.jit.script(mod1)
```

I get the following error:
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-10-a7ebc7af32c7> in <module>
----> 1 scripted_mod1 = torch.jit.script(mod)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
    895
    896     if isinstance(obj, torch.nn.Module):
--> 897         return torch.jit._recursive.create_script_module(
    898             obj, torch.jit._recursive.infer_methods_to_compile
    899         )

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    350     check_module_initialized(nn_module)
    351     concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    353
    354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    408     # Compile methods if necessary
    409     if concrete_type not in concrete_type_store.methods_compiled:
--> 410         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    411         torch._C._run_emit_module_hook(cpp_module)
    412         concrete_type_store.methods_compiled.add(concrete_type)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    302     property_rcbs = [p.resolution_callback for p in property_stubs]
    303
--> 304     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    305
    306

RuntimeError:
x_input is not defined in the false branch:
  File "<ipython-input-7-d08ed7ff42ec>", line 12
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-7-d08ed7ff42ec>", line 16
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x
```

So, basically, TorchScript does not recognize that for `mod1` the True branch of both `if` statements is never taken. Moreover, if we create an instance that actually uses the skip connection,
```python
mod2 = SadModule(True)
scripted_mod2 = torch.jit.script(mod2)
```

we get an error as well:
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-b5ca61d8aa73> in <module>
----> 1 scripted_mod2 = torch.jit.script(mod2)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
    895
    896     if isinstance(obj, torch.nn.Module):
--> 897         return torch.jit._recursive.create_script_module(
    898             obj, torch.jit._recursive.infer_methods_to_compile
    899         )

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    350     check_module_initialized(nn_module)
    351     concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    353
    354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    408     # Compile methods if necessary
    409     if concrete_type not in concrete_type_store.methods_compiled:
--> 410         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    411         torch._C._run_emit_module_hook(cpp_module)
    412         concrete_type_store.methods_compiled.add(concrete_type)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    302     property_rcbs = [p.resolution_callback for p in property_stubs]
    303
--> 304     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    305
    306

RuntimeError:
x_input is not defined in the false branch:
  File "<ipython-input-18-ac8b9713c789>", line 17
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-18-ac8b9713c789>", line 21
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x
```

So in this case TorchScript does not understand that both `if`s are true and that `x_input` is in fact well defined.
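The root cause is that `use_skip` is an ordinary attribute whose value is only known at run time, so TorchScript compiles and type-checks both branches of each `if self.use_skip:`. A minimal sketch of one workaround, assuming it is acceptable to keep a reference to the input even when the skip is off: bind `x_input` unconditionally so it is defined on every path. The class name `SkipFriendlyModule` is hypothetical.

```python
import torch
from torch import nn


class SkipFriendlyModule(nn.Module):
    """Hypothetical variant of SadModule: x_input is bound on every
    control-flow path, so TorchScript's definedness check passes."""

    def __init__(self, use_skip: bool):
        super().__init__()
        self.use_skip = use_skip  # ordinary bool attribute, checked at run time
        self.layer = nn.Linear(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_input = x  # always defined, whatever use_skip is
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x


inp = torch.randn(4, 2)
for use_skip in (False, True):
    mod = SkipFriendlyModule(use_skip)
    scripted = torch.jit.script(mod)
    # The scripted module should agree with eager mode.
    assert torch.allclose(scripted(inp), mod(inp))
```

The branch is still present in the compiled code and evaluated on every call; for a single boolean check that cost is usually negligible.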
To work around the problem, I could split the class into two separate classes, like this:

```python
class SadModuleNoSkip(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer
    (no skip connection).
    """

    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x = self.layer(x)
        return x


class SadModuleSkip(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer
    with a skip connection.
    """

    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x_input = x
        x = self.layer(x)
        x = x + x_input
        return x
```

However, I'm working on a large codebase and would have to repeat this for many classes, which is time-consuming and could introduce bugs. Moreover, the modules I'm working on are usually huge convolutional networks, and the `if` only controls the presence of one extra batch normalization layer. Having to maintain pairs of classes that are 99% identical except for a single batch-norm layer seems undesirable to me.
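For the batch-norm case specifically, a common pattern is to make the architectural choice once in `__init__` by selecting the submodule, and have `forward` call it unconditionally. `nn.Identity` returns its input unchanged, so "no batch norm" needs no `if` at all. The `ConvBlock` class and its layer sizes are hypothetical, a sketch rather than the asker's code.

```python
import torch
from torch import nn


class ConvBlock(nn.Module):
    """Hypothetical conv block whose only architectural switch is an
    optional batch norm, resolved by choosing a submodule in __init__."""

    def __init__(self, in_ch: int, out_ch: int, use_bn: bool):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        # nn.Identity passes its input through, so "no batch norm"
        # is a no-op submodule rather than a branch in forward().
        self.norm = nn.BatchNorm2d(out_ch) if use_bn else nn.Identity()
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.norm(self.conv(x)))


x = torch.randn(2, 3, 8, 8)
for use_bn in (False, True):
    block = ConvBlock(3, 4, use_bn).eval()
    scripted = torch.jit.script(block)
    assert scripted(x).shape == (2, 4, 8, 8)
```

One side effect: `nn.Identity` has no parameters or buffers, so `norm.*` entries appear in the `state_dict` only when `use_bn=True`.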
Is there any way to help TorchScript deal with these branches?

Edit: added a minimal working example.
Update: it doesn't work even if I type-hint `use_skip` as a constant:

```python
from typing import Final

class SadModule(nn.Module):
    use_skip: Final[bool]
    ...
```

Posted on 2021-05-07 13:03:34
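I opened an issue on GitHub. The project maintainers explained that using `Final` is the way to go. Be careful, though: as of today (May 7, 2021) this feature is still in development (albeit in its final stages; see the feature tracker here).

Although it is not in an official release yet, it is available in the PyTorch nightly builds, so you can install a pytorch-nightly build as explained on the website (scroll down to Install and select Preview (Nightly)), or wait for the next release.

For anyone reading this answer a few months from now, this feature should already be part of the regular PyTorch releases.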
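A minimal sketch of the `Final` approach, assuming a PyTorch build that already includes the feature (the answer gives no exact version number). Because `use_skip` is a compile-time constant, TorchScript should compile only the branch that is actually taken, so the original `forward` with the conditionally defined `x_input` scripts for both settings. Inspecting `scripted.graph` is one way to check: no `prim::If` node should remain for the skip logic.

```python
from typing import Final

import torch
from torch import nn


class SadModule(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer, with an
    optional skip connection fixed at construction time."""

    # Final makes use_skip a TorchScript constant, so each
    # `if self.use_skip:` can be resolved during compilation.
    use_skip: Final[bool]

    def __init__(self, use_skip: bool):
        super().__init__()
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x


inp = torch.randn(4, 2)
for use_skip in (False, True):
    mod = SadModule(use_skip)
    scripted = torch.jit.script(mod)
    assert torch.allclose(scripted(inp), mod(inp))
    # The dead branch was never compiled, so no prim::If node should appear.
    assert "prim::If" not in str(scripted.graph)
```

Since constants are part of the module's compiled type, `SadModule(False)` and `SadModule(True)` are compiled separately, each with only its own branch.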
https://stackoverflow.com/questions/67252744