首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch值阈值化和所有其他值的零化

PyTorch值阈值化和所有其他值的零化
EN

Stack Overflow用户
提问于 2020-07-15 13:19:29
回答 2查看 1.4K关注 0票数 1

我在PyTorch中有二维张量,表示模型的信任。我要:

  • 如果行中的第二个值大于或等于阈值,则应将所有其他值更改为0
  • ,否则不应更改

简单的办法是:

  • 迭代行
  • ,如果值大于或等于,检查第2值
  • ,创建一行零,将第2值从行更改为第2值,替换行
  • ,否则不要做任何

F 215

然而,这是低效的。有没有一种矢量化/张力化的方法来做到这一点?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-07-15 13:43:28

我首先要构造一个新的零矩阵,然后根据需要将项目从矩阵移到零矩阵。复制第二个元素低于阈值的行中的所有行。对于所有其他行,只复制第二个元素。

代码语言:javascript
复制
import torch
threshold = .2

X = torch.rand((100, 10))
new = torch.zeros_like(X)
mask = X[:, 2] <= threshold
new[mask] = X[mask]
new[~mask, 2] = X[~mask, 2]
票数 2
EN

Stack Overflow用户

发布于 2020-07-15 16:34:49

尝尝这个

代码语言:javascript
复制
import numpy as np
x[(x[:,1] >= 0.5).nonzero(), np.r_[0, 2:x.shape[1]]] = 0.0

首先,使用(x[:,1] >= 0.5).nonzero()获取行索引,

然后取除第二列外的列索引np.r_[0, 2:x.shape[1]]

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62915838

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档