首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >one-hot ID矩阵的索引数组

one-hot ID矩阵的索引数组
EN

Stack Overflow用户
提问于 2020-04-29 06:20:05
回答 3查看 35关注 0票数 0

假设我有一个数组0,2,我想输出一个由基于0,2的一个热向量组成的矩阵,如[ 1,0,0] (请注意,输出矩阵的第二维假设为3,但它可以是大于argmax(0,2)的任何数字,即2。

我只能想到用这种方式来实现这个功能。有没有更简单的方法。

代码语言:javascript
复制
t = torch.tensor([0,2])
dim2_size = 3
id_t = torch.zeros(t.shape[0], dim2_size)
row_idx = 0
for i in t:
  col_idx = i.item()
  id_t[row_idx, col_idx] = 1
  row_idx += 1
id_t
EN

回答 3

Stack Overflow用户

发布于 2020-04-29 06:52:05

这个没有使用任何循环。

代码语言:javascript
复制
import torch

labels = torch.tensor([0, 2])
one_hot = torch.zeros(labels.shape[0], torch.max(labels)+1)
one_hot[torch.arange(labels.shape[0]), labels] = 1

print(one_hot)
代码语言:javascript
复制
tensor([[1., 0., 0.],
        [0., 0., 1.]])
票数 0
EN

Stack Overflow用户

发布于 2020-04-29 09:01:46

通过numpy的方法更加简单

代码语言:javascript
复制
import torch
import numpy as np

labels =[0,2]
output=np.eye(max(labels)+1)[labels]
print(torch.from_numpy(output))
票数 0
EN

Stack Overflow用户

发布于 2020-04-29 11:05:41

在Pytorch中,这最好通过使用scatter_来完成。

代码语言:javascript
复制
t = torch.tensor([0,2]).unsqueeze(0)
num_dims = 3
id_t = torch.zeros(num_dims, t.shape[1]).scatter_(0, t, 1)

这将为您提供id_t为:

代码语言:javascript
复制
tensor([[1., 0.],
        [0., 0.],
        [0., 1.]])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61490959

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档