首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在TorchScript中跟踪张量大小

在TorchScript中跟踪张量大小
EN

Stack Overflow用户
提问于 2021-05-06 07:42:38
回答 2查看 2K关注 0票数 1

我正在通过PyTorch跟踪导出一个TorchScript模型,但我面临着一些问题。具体来说,我必须对张量大小执行一些操作,但是JIT编译器将变量形状硬编码为常量,与不同尺寸的张量兼容。

例如,创建类:

代码语言:javascript
复制
class Foo(nn.Module):
    """Toy class that plays with tensor shape to showcase tracing issue.

    It creates a new tensor with the same shape as the input one, except
    for the last dimension, which is doubled. This new tensor is filled
    based on the values of the input.
    """
    def __init__(self):
        nn.Module.__init__(self)

    def forward(self, x):
        new_shape = (x.shape[0], 2*x.shape[1])  # incriminated instruction
        x2 = torch.empty(size=new_shape)
        x2[:, ::2] = x
        x2[:, 1::2] = x + 1
        return x2

并运行测试代码:

代码语言:javascript
复制
x = torch.randn((3, 5))  # create example input

foo = Foo()
traced_foo = torch.jit.trace(foo, x)  # trace
print(traced_foo(x).shape)  # obviously this works
print(traced_foo(x[:, :4]).shape)  # but fails with a different shape!

我可以通过脚本来解决这个问题,但在这种情况下,我确实需要使用跟踪。此外,我认为跟踪应该能够正确地处理张量大小的操纵。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-05-06 09:44:34

,但在这种情况下,我确实需要使用跟踪

您可以在任何需要的地方自由地混合torch.scripttorch.jit。例如,可以这样做:

代码语言:javascript
复制
import torch


class MySuperModel(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.scripted = torch.jit.script(Foo(*args, **kwargs))
        self.traced = Bar(*args, **kwargs)

    def forward(self, data):
        return self.scripted(self.traced(data))

model = MySuperModel()
torch.jit.trace(model, (input1, input2))

您还可以将依赖于形状的部分功能移动到分离函数并用@torch.jit.script装饰它。

代码语言:javascript
复制
@torch.jit.script
def _forward_impl(x):
    new_shape = (x.shape[0], 2*x.shape[1])  # incriminated instruction
    x2 = torch.empty(size=new_shape)
    x2[:, ::2] = x
    x2[:, 1::2] = x + 1
    return x2

class Foo(nn.Module):
    def forward(self, x):
        return _forward_impl(x)

因为它必须理解您的代码,因此除了script之外,没有其他方法了。通过跟踪,它只记录您在张量上执行的操作,并且不了解依赖于数据(或数据形状)的控制流。

无论如何,这应该涵盖大多数的情况,如果没有,你应该更具体。

票数 1
EN

Stack Overflow用户

发布于 2022-04-20 04:15:00

此错误已修复为1.10.2。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67413808

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档