假设蒙版如下所示:
mask = torch.tensor([
[True, True, False, True, False],
[True, False, True, True, True ],
])我想在每一行中用顺序值对True值进行编号。我不关心False点中的内容,所以为了简单起见,0。因此,所期望的结果是
tensor([[0, 1, 0, 2, 0], # 0 1 _ 2 _
[0, 0, 1, 2, 3]]) # 0 _ 1 2 3我希望这能奏效:
replacements = torch.arange(mask.size(1)).expand(mask.size())
target = torch.zeros(mask.size(), dtype=int)
target.masked_scatter(mask, replacements)不幸的是,masked_scatter忽略了replacements的形状,因此这段代码的结果是:
tensor([[0, 1, 0, 2, 0], # 0 1 _ 2 _
[3, 0, 4, 0, 1]]) # 3 _ 4 0 1我需要做些什么呢?
发布于 2020-05-19 05:13:58
我会尝试一下torch.cumsum:torch.cumsum(mask,dim=1) -1) * mask
完整的例子
import torch
mask = torch.tensor([
[True, True, False, True, False],
[True, False, True, True, True ],
])
result=torch.cumsum(mask,dim=1) -1) * mask
print(result)这将打印:
tensor([[0, 1, 0, 2, 0],
[0, 0, 1, 2, 3]])https://stackoverflow.com/questions/61883535
复制相似问题