
Duplicate branches in TorchScript model export

Stack Overflow user
Asked on 2021-04-25 11:15:31
Answers 1, Views 169, Follows 0, Votes 0

I'm trying to export a PyTorch model to TorchScript via scripting, and I'm stuck. I created a toy class to demonstrate the problem:

import torch
from torch import nn


class SadModule(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer. Can optionally
    use a skip connection. The usage of the skip connection or not is an
    architectural choice.

    """
    def __init__(self, use_skip: bool):
        nn.Module.__init__(self)
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x

It basically consists of a single linear layer and an optional skip connection. If I try to script it with

mod1 = SadModule(False)
scripted_mod1 = torch.jit.script(mod1)

I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-10-a7ebc7af32c7> in <module>
----> 1 scripted_mod1 = torch.jit.script(mod)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
    895
    896     if isinstance(obj, torch.nn.Module):
--> 897         return torch.jit._recursive.create_script_module(
    898             obj, torch.jit._recursive.infer_methods_to_compile
    899         )

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    350     check_module_initialized(nn_module)
    351     concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    353
    354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    408     # Compile methods if necessary
    409     if concrete_type not in concrete_type_store.methods_compiled:
--> 410         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    411         torch._C._run_emit_module_hook(cpp_module)
    412         concrete_type_store.methods_compiled.add(concrete_type)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    302     property_rcbs = [p.resolution_callback for p in property_stubs]
    303
--> 304     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    305
    306

RuntimeError:

x_input is not defined in the false branch:
  File "<ipython-input-7-d08ed7ff42ec>", line 12
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-7-d08ed7ff42ec>", line 16
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x

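So TorchScript basically fails to recognize that, for mod1, the True branch of neither if statement will ever be taken. Moreover, if we create an instance that actually uses the skip connection,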

mod2 = SadModule(True)
scripted_mod2 = torch.jit.script(mod2)

we get another error, with the same message as before:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-b5ca61d8aa73> in <module>
----> 1 scripted_mod2 = torch.jit.script(mod2)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_script.py in script(obj, optimize, _frames_up, _rcb)
    895
    896     if isinstance(obj, torch.nn.Module):
--> 897         return torch.jit._recursive.create_script_module(
    898             obj, torch.jit._recursive.infer_methods_to_compile
    899         )

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    350     check_module_initialized(nn_module)
    351     concrete_type = get_module_concrete_type(nn_module, share_types)
--> 352     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    353
    354 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    408     # Compile methods if necessary
    409     if concrete_type not in concrete_type_store.methods_compiled:
--> 410         create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    411         torch._C._run_emit_module_hook(cpp_module)
    412         concrete_type_store.methods_compiled.add(concrete_type)

~/Software/miniconda3/envs/pytorch3d/lib/python3.8/site-packages/torch/jit/_recursive.py in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    302     property_rcbs = [p.resolution_callback for p in property_stubs]
    303
--> 304     concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
    305
    306

RuntimeError:

x_input is not defined in the false branch:
  File "<ipython-input-18-ac8b9713c789>", line 17
    def forward(self, x):
        if self.use_skip:
        ~~~~~~~~~~~~~~~~~
            x_input = x
            ~~~~~~~~~~~ <--- HERE
        x = self.layer(x)
        if self.use_skip:
and was used here:
  File "<ipython-input-18-ac8b9713c789>", line 21
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
                    ~~~~~~~ <--- HERE
        return x

So in this case TorchScript does not understand that both if conditions are true and that x_input is in fact well defined.
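A minimal sketch of one common workaround (not from the original post): because `self.use_skip` is an ordinary attribute, TorchScript treats it as a runtime value and type-checks both branches of each `if`. Assigning `x_input` unconditionally makes it defined on every path, so the check passes for both configurations. The class name `SkipModuleFixed` is hypothetical.

```python
import torch
from torch import nn


class SkipModuleFixed(nn.Module):
    """Hypothetical variant of SadModule that TorchScript can compile."""

    def __init__(self, use_skip: bool):
        super().__init__()
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep a reference to the input on every path, so the compiler never
        # sees x_input as undefined when it checks the second `if`.
        x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x


# Both configurations script without the "not defined in the false branch" error.
scripted_no_skip = torch.jit.script(SkipModuleFixed(False))
scripted_skip = torch.jit.script(SkipModuleFixed(True))
```

The extra assignment only aliases the input tensor (no copy is made), so it costs essentially nothing when the skip connection is disabled.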

To avoid the problem, I could split the class into two separate classes, as follows:

class SadModuleNoSkip(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer, without a
    skip connection.

    """
    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x = self.layer(x)
        return x

class SadModuleSkip(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer, with a skip
    connection that adds the input back to the output.

    """
    def __init__(self):
        nn.Module.__init__(self)
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        x_input = x
        x = self.layer(x)
        x = x + x_input
        return x

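However, I'm working on a huge codebase, and I would have to repeat this process for many classes, which is time-consuming and could introduce bugs. Moreover, the modules I'm working with are often large convolutional networks, where the if only controls the presence of an extra batch normalization layer. Having to maintain two classes that are 99% identical, differing only by a single batch-norm layer, seems undesirable to me.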
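For the batch-norm case, a widely used pattern (an assumption about how the real blocks are built, not something from the original post) is to choose the submodule once in `__init__`, either `nn.BatchNorm2d` or `nn.Identity`, so that `forward` contains no `if` at all and there is nothing for TorchScript to reject. `ConvBlock`, `channels`, and `use_bn` are hypothetical names.

```python
import torch
from torch import nn


class ConvBlock(nn.Module):
    """Hypothetical conv block whose only architectural switch is an extra BatchNorm."""

    def __init__(self, channels: int, use_bn: bool):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # Decide once at construction time; nn.Identity just returns its input.
        self.norm = nn.BatchNorm2d(channels) if use_bn else nn.Identity()
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A single straight-line path: no branch for TorchScript to type-check.
        return self.act(self.norm(self.conv(x)))


scripted_plain = torch.jit.script(ConvBlock(8, use_bn=False))
scripted_bn = torch.jit.script(ConvBlock(8, use_bn=True))
```

This keeps one class per block; the configuration lives in which submodule is stored, not in control flow.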

Is there any way to help TorchScript handle these branches?

Edit: added a minimal reproducible example.

Update: it still doesn't work even if I type-hint use_skip as a constant:

from typing import Final

class SadModule(nn.Module):
    use_skip: Final[bool]
    ...

1 answer

Stack Overflow user

Answered on 2021-05-07 13:03:34

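I opened an issue on GitHub. The project maintainers explained that using Final does work. Be careful, though: as of today (May 7, 2021) this feature is still under development (albeit in its final stages; see the feature tracker here).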

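Although it is not in an official release yet, it is available in the PyTorch nightly builds, so you can either install a pytorch-nightly build as explained on the website (scroll down to Install and select Preview (Nightly)) or wait for the next release.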

For anyone reading this answer a few months from now, the feature should already be included in a stable PyTorch release.
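A minimal sketch of the `Final` approach the answer describes, assuming a PyTorch version recent enough to include the feature (the answer only says it was in nightly builds as of May 2021; the exact release is not stated). Annotating `use_skip` as `Final[bool]` makes it a compile-time constant of each scripted module, so TorchScript can compile only the branch that is actually taken instead of type-checking both. `forward` is the question's original code, unchanged.

```python
from typing import Final

import torch
from torch import nn


class SadModule(nn.Module):
    """Takes a (*, 2) input and runs it through a linear layer, with an
    optional skip connection chosen at construction time."""

    # Final marks the attribute as a TorchScript constant for this module.
    use_skip: Final[bool]

    def __init__(self, use_skip: bool):
        super().__init__()
        self.use_skip = use_skip
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        # With a constant condition, only the branch that is taken gets compiled,
        # so x_input is never referenced when use_skip is False.
        if self.use_skip:
            x_input = x
        x = self.layer(x)
        if self.use_skip:
            x = x + x_input
        return x


scripted_mod1 = torch.jit.script(SadModule(False))
scripted_mod2 = torch.jit.script(SadModule(True))
```

Because the constant is part of the compiled module type, instances with different `use_skip` values are compiled separately; on older releases this code may still fail with the error shown in the question.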

Votes 0
Original page content provided by Stack Overflow. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/67252744
