首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用脚本转换移相器变压器?

如何使用脚本转换移相器变压器?
EN

Stack Overflow用户
提问于 2020-06-09 15:31:16
回答 1查看 760关注 0票数 1

我正在尝试编译Pytory转换器,以便在C++中运行它:

代码语言:javascript
复制
from torch.nn import TransformerEncoder, TransformerEncoderLayer
encoder_layers = TransformerEncoderLayer(1000, 8, 512, 0.1)
transf = TransformerEncoder(encoder_layers, 6)
sm = torch.jit.script(transf)

但是我发现了一个错误:

代码语言:javascript
复制
RuntimeError:  Expected a default value of type Tensor on parameter
"src_mask":   File "C:\Program Files (x86)\Microsoft Visual
Studio\Shared\Python36_64\lib\site-packages\torch\nn\modules\transformer.py",
line 271
    def forward(self, src, src_mask=None, src_key_padding_mask=None):
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~... <--- HERE
        r"""Pass the input through the encoder layer.

看上去好像是火把变压器模块出了问题。

有没有办法在C++中运行火炬变压器?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-10 00:40:56

您需要升级到PyTorch 1.5.0,以前的版本不支持将Transformers转换为TorchScript (JIT)模块。

代码语言:javascript
复制
pip install torch===1.5.0 -f https://download.pytorch.org/whl/torch_stable.html

在1.5.0中,您将看到一些有关参数声明为常量的警告,如:

代码语言:javascript
复制
UserWarning: 'q_proj_weight' was found in ScriptModule constants,  but it is a non-constant parameter. Consider removing it.

这些都可以被安全地忽略。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62286188

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档