首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TorchScript需要源代码访问才能对collections.deque进行编译

TorchScript需要源代码访问才能对collections.deque进行编译
EN

Stack Overflow用户
提问于 2021-03-14 20:01:16
回答 1查看 1.4K关注 0票数 3

我正在尝试将PyTorch 福姆模型转换为TorchScript。当我开始用@torch.jit.script注释一些类时,我就得到了一个错误:

OSError: Can't get source for <class 'collections.deque'>. TorchScript requires source access in order to carry out compilation, make sure original .py files are available.

因此,据我所知,在CPython中实现的类不能被TorchScript编译器读取。我没有找到任何纯Python实现。我怎样才能克服这个问题?

下面是我要注释的课程:

代码语言:javascript
复制
import queue
import collections
import threading
import torch

@torch.jit.script
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
    call `register(id)` and obtain an `SlavePipe` to communicate with the master.
    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
    and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine to message passed
    back to each slave devices.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register an slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages were first collected from each devices (including the master device), and then
        an callback will be invoked to compute the message to be sent back to each devices
        (including the master device).

        Args:
            master_msg: the message that the master want to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-03-15 21:55:41

将TorchScript生成方法从torch.jit.script切换到torch.jit.trace,它工作正常,无需对任何内容进行注释。另外,torch.onnx.export有时也能工作。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66628965

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档