首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >masked_scatter但是行吗?

masked_scatter但是行吗?
EN

Stack Overflow用户
提问于 2020-05-19 04:56:14
回答 1查看 485关注 0票数 2

假设蒙版如下所示:

代码语言:javascript
复制
mask = torch.tensor([
  [True,  True,  False, True,  False],
  [True,  False, True,  True,  True ],
])

我想在每一行中用顺序值对True值进行编号。我不关心False点中的内容,所以为了简单起见,0。因此,所期望的结果是

代码语言:javascript
复制
tensor([[0, 1, 0, 2, 0],    # 0 1 _ 2 _
        [0, 0, 1, 2, 3]])   # 0 _ 1 2 3

我希望这能奏效:

代码语言:javascript
复制
replacements = torch.arange(mask.size(1)).expand(mask.size())
target = torch.zeros(mask.size(), dtype=int)
target.masked_scatter(mask, replacements)

不幸的是,masked_scatter忽略了replacements的形状,因此这段代码的结果是:

代码语言:javascript
复制
tensor([[0, 1, 0, 2, 0],    # 0 1 _ 2 _
        [3, 0, 4, 0, 1]])   # 3 _ 4 0 1

我需要做些什么呢?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-05-19 05:13:58

我会尝试一下torch.cumsumtorch.cumsum(mask,dim=1) -1) * mask

完整的例子

代码语言:javascript
复制
import torch
mask = torch.tensor([
  [True,  True,  False, True,  False],
  [True,  False, True,  True,  True ],
])
result=torch.cumsum(mask,dim=1) -1) * mask
print(result)

这将打印:

代码语言:javascript
复制
tensor([[0, 1, 0, 2, 0],
        [0, 0, 1, 2, 3]])
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61883535

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档