首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Torchscript与用于张量列表的torch.cat不兼容

Torchscript与用于张量列表的torch.cat不兼容
EN

Stack Overflow用户
提问于 2020-02-28 01:43:41
回答 1查看 464关注 0票数 0

当在torchscript中使用时,Torch.cat会抛出张量列表错误

以下是重现错误的最小可重现示例

代码语言:javascript
复制
import torch
import torch.nn as nn

"""
Smallest working bug for torch.cat torchscript
"""


class Model(nn.Module):
    """dummy model for showing error"""

    def __init__(self):
        super(Model, self).__init__()
        pass

    def forward(self):
        a = torch.rand([6, 1, 12])
        b = torch.rand([6, 1, 12])
        out = torch.cat([a, b], axis=2)
        return out


if __name__ == '__main__':
    model = Model()
    print(model())  # works
    torch.jit.script(model)  # throws error

预期的结果将是torch.cat的火炬脚本输出。下面是提供的错误消息:

代码语言:javascript
复制
File "/home/anil/.conda/envs/rnn/lib/python3.7/site-packages/torch/jit/__init__.py", line 1423, in _create_methods_from_stubs
    self._c._create_methods(self, defs, rcbs, defaults)
RuntimeError: 
Arguments for call are not valid.
The following operator variants are available:

  aten::cat(Tensor[] tensors, int dim=0) -> (Tensor):
  Keyword argument axis unknown.

  aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> (Tensor(a!)):
  Argument out not provided.

The original call is:
at smallest_working_bug_torch_cat_torchscript.py:19:14
    def forward(self):
        a = torch.rand([6, 1, 12])
        b = torch.rand([6, 1, 12])
        out = torch.cat([a, b], axis=2)
              ~~~~~~~~~ <--- HERE
        return out

请让我知道这个问题的解决方法或解决办法。

谢谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-02-28 03:12:36

axis更改为dim可以修复错误,最初的解决方案发布在here

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60438983

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档