首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch的grid_sample转换为CoreML (通过协同工具)

PyTorch的grid_sample转换为CoreML (通过协同工具)
EN

Stack Overflow用户
提问于 2021-03-20 19:29:58
回答 2查看 885关注 0票数 3

torch.nn.functional.grid_sample (源这里,单击文档文档)目前不受CoreML支持的操作(以及它们的转换实用程序库:协同工具)。

我正在寻找一种将下面所示的层从PyTorch的torchscript (docs 这里)导出到CoreML的方法(或者使用通过Swift创建的自定义op,或者通过grid_sample的高效PyTorch重写)。

获取详细信息和提示,请参阅提示部分

最小可验证示例

代码语言:javascript
复制
import coremltools as ct
import torch


class GridSample(torch.nn.Module):
    def forward(self, inputs, grid):
        # Rest could be the default behaviour, e.g. bilinear
        return torch.nn.functional.grid_sample(inputs, grid, align_corners=True)


# Image could also have more in_channels, different dimension etc.,
# for example (2, 32, 64, 64)
image = torch.randn(2, 3, 32, 32)  # (batch, in_channels, width, height)
grid = torch.randint(low=-1, high=2, size=(2, 64, 64, 2)).float()

layer = GridSample()
# You could use `torch.jit.script` if preferable
scripted = torch.jit.trace(layer, (image, grid))

# Sanity check
print(scripted(image, grid).shape)


# Error during conversion
coreml_layer = ct.converters.convert(
    scripted,
    source="pytorch",
    inputs=[
        ct.TensorType(name="image", shape=image.shape),
        ct.TensorType(name="grid", shape=grid.shape),
    ],
)

这将引发以下错误:

代码语言:javascript
复制
Traceback (most recent call last):
  File "/home/REDACTED/Downloads/sample.py", line 23, in <module>
    coreml_layer = ct.converters.convert(
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/_converters_entry.py", line 175, in convert
    mlmodel = mil_convert(
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 128, in mil_convert
    proto = mil_convert_to_proto(, convert_from, convert_to,
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 171, in mil_convert_to_proto
    prog = frontend_converter(, **kwargs)
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 85, in __call__
    return load(*args, **kwargs)
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 81, in load
    raise e
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 73, in load
    prog = converter.convert()
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 227, in convert
    convert_nodes(self.context, self.graph)
  File "/home/REDACTED/.conda/envs/REDACTED/lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 54, in convert_nodes
    raise RuntimeError(
RuntimeError: PyTorch convert function for op 'grid_sampler' not implemented.

依赖关系

Python (conda):

  • coremltools==4.1
  • torch==1.8.0

您还可以使用nightly/master构建(至少在写作当天:2021-03-20)。

提示

这些问题被分成两种可能的解决办法,我目前看到:

仅限PyTorch

torch.nn.functional.grid_sample 从零开始重写

  • 这将只需要在张量上坚持PyTorch操作,因为循环(例如三重嵌套)会挂起转换器,而且效率太低。
  • 您不能使用__getitem__ on list 或相关类型的--似乎与torch.Tensor一起工作,但是有问题,所以如果您得到RuntimeError: PyTorch convert function for op '__getitem__' not implemented,您应该记住它。

Pros:

  • 不需要两种语言&坚持单一技术

Cons:

  • 有循环限制,需要坚持矢量化操作(大部分/所有时间)

Swift & CoreML

负责运行grid_sample.的注册自定义层只有CPU才能实现(虽然使用Apple的Metal进行GPU加速会很棒)。

由于我对Swift不感兴趣,我收集了一些可能有助于您的资源:

Pros:

  • 使用循环和对算法进行更精细控制的可能性
  • 可能更容易一些,因为我们不局限于CoreML当前可以读取的操作

Cons:

  • 两种语言
  • 稀疏文档
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2022-05-30 15:21:58

显然,一些好的灵魂看到了我们的挣扎,并提供了自定义操作使用MIL (中间表示语言的CoreML)。

博客帖子,我在那里找到了解决方案和栅格样例

我不知道为什么OP没有在这里发布它,但如果您想要为您的解决方案采取一些这样的观点,请以评论的方式回复!

以下是完全操作转换代码:

代码语言:javascript
复制
from coremltools.converters.mil import register_torch_op, register_op
from coremltools.converters.mil.mil.ops.defs._op_reqs import *

# Custom operator for `torch.nn.functional.grid_sample`
@register_op(doc_str="Custom Grid Sampler", is_custom_op=True)
class custom_grid_sample(Operation):
    input_spec = InputSpec(
        x = TensorInputType(),
        grid = TensorInputType(),
        mode = StringInputType(const=True, optional=True),
        padding_mode = StringInputType(const=True, optional=True),
        align_corners = BoolInputType(const=True, optional=True)
    )

    bindings = {
        "class_name": "CustomGridSampler",
        "input_order": ["x", "grid"],
        "parameters": ["mode", "padding_mode", "align_corners"],
        "description": "Custom Grid Sampler"
    }

    def __init__(self, **kwargs):
        super(custom_grid_sample, self).__init__(**kwargs)

    def type_inference(self):
        x_type = self.x.dtype
        x_shape = self.x.shape

        grid_type = self.grid.dtype
        grid_shape = self.grid.shape

        assert len(x_shape) == len(grid_shape) == 4
        assert grid_shape[-1] == 2

        shape = list(x_shape)
        shape[-2] = grid_shape[1]
        shape[-1] = grid_shape[2]
        return types.tensor(x_type, tuple(shape))


@register_torch_op
def grid_sampler(context, node):
    inputs = _get_inputs(context, node)
    x = inputs[0]
    grid = inputs[1]
    mode = node.attr.get("mode", "bilinear")
    padding_mode = node.attr.get("padding_mode", "zeros")
    align_corners = node.attr.get("align_corners", False)
    x = mb.custom_grid_sample(
        x=x,
        grid=grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        name=node.name
    )
    context.add(x)
票数 1
EN

Stack Overflow用户

发布于 2021-03-21 12:39:01

这不是确切的答案,而是一些研究。grid_sample本质上是稀疏矩阵运算,其思想是尝试使其稠密。下面的代码演示了如何做到这一点。它可能很慢,并且要求grid是静态的,以便从模型转换中消除grid_sample,但是有点工作。

我们的目标是得到线性形式的变换。这里,为了得到稠密矩阵,我们给单位对角的‘网格_样本’,结果是矩阵保持变换,我们正在寻找。若要进行命名转换,请将平坦图像相乘到此矩阵。正如您在这里看到的batch=1,必须对每个grid独立地进行转换。

你的代码:

代码语言:javascript
复制
in_sz  = 2;    out_sz = 4;    batch  = 1;    ch     = 3

class GridSample(torch.nn.Module):
    def forward(self, inputs, grid):
        # Rest could be the default behaviour, e.g. bilinear
        return torch.nn.functional.grid_sample(inputs, grid, align_corners=True)

image = torch.randn( batch, ch, in_sz, in_sz)  # (batch, in_channels, width, height)
grid = torch.randint(low=-1, high=2, size=( batch, out_sz, out_sz, 2)).float()

layer = GridSample()
scripted = torch.jit.trace(layer, (image, grid))
print(scripted(image, grid))

退出:

代码语言:javascript
复制
tensor([[[[-0.8226, -0.4457, -0.3382, -0.0795],
          [-0.4457, -0.0052, -0.8226, -0.6341],
          [-0.4457, -0.8226, -0.4457, -0.6341],
          [-0.4510, -0.3382, -0.4457, -0.0424]],

         [[-1.0090, -1.6029, -1.3813, -0.1212],
          [-1.6029, -2.7920, -1.0090, -1.3060],
          [-1.6029, -1.0090, -1.6029, -1.3060],
          [-0.5651, -1.3813, -1.6029, -1.4566]],

         [[ 0.1482,  0.7313,  0.8916,  1.8723],
          [ 0.7313,  0.8144,  0.1482,  0.4398],
          [ 0.7313,  0.1482,  0.7313,  0.4398],
          [ 1.0103,  0.8916,  0.7313,  1.3434]]]])

改划:

代码语言:javascript
复制
oness  = torch.ones( in_sz*in_sz )
diagg  = torch.diag( oness ).reshape( 1, in_sz*in_sz, in_sz, in_sz )
denser = torch.nn.functional.grid_sample( diagg, grid, align_corners=True).reshape( in_sz*in_sz, out_sz*out_sz ).transpose(0,1)
print (denser.shape)
print (image.shape)
image_flat = image.reshape( batch, ch, in_sz*in_sz )
print (image_flat.shape)
print( torch.nn.functional.linear( image_flat, denser ).reshape( batch, ch, out_sz, out_sz ) )

退出:

代码语言:javascript
复制
torch.Size([16, 4])
torch.Size([1, 3, 2, 2])
torch.Size([1, 3, 4])
tensor([[[[-0.8226, -0.4457, -0.3382, -0.0795],
          [-0.4457, -0.0052, -0.8226, -0.6341],
          [-0.4457, -0.8226, -0.4457, -0.6341],
          [-0.4510, -0.3382, -0.4457, -0.0424]],

         [[-1.0090, -1.6029, -1.3813, -0.1212],
          [-1.6029, -2.7920, -1.0090, -1.3060],
          [-1.6029, -1.0090, -1.6029, -1.3060],
          [-0.5651, -1.3813, -1.6029, -1.4566]],

         [[ 0.1482,  0.7313,  0.8916,  1.8723],
          [ 0.7313,  0.8144,  0.1482,  0.4398],
          [ 0.7313,  0.1482,  0.7313,  0.4398],
          [ 1.0103,  0.8916,  0.7313,  1.3434]]]])

嗯,可能不是很有效,我希望这至少能让人觉得好笑。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66725654

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档