由于Python builtin <built-in function sum> is currently not supported in Torchscript:,我正在寻找一种推荐的方法来执行以下操作:
class Model(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return sum(x.tolist())
model = Model()
model = torch.jit.script(model)
model(torch.arange(10))发布于 2021-05-18 08:23:44
最简单的方法是直接使用PyTorch的sum:
class Model(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return torch.sum(x)如果出于某些原因,这不是一个选项,您必须使用类型规范和显式循环(请注意类型提示!):
import typing
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x) -> int:
x: typing.List[int] = x.tolist()
result = 0
for elem in x:
result += elem
return resulthttps://stackoverflow.com/questions/67576054
复制相似问题