首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用字典替换NumPy数组中的值会给出不明确的结果,为什么呢?

用字典替换NumPy数组中的值会给出不明确的结果,为什么呢?
EN

Stack Overflow用户
提问于 2020-11-29 13:43:28
回答 1查看 21关注 0票数 1

因此,我有一个数组,其中包含一些单词,我正在尝试执行一次热编码。

假设输入是AI DSA DSA AI ML ML AI DS DS AI C AI ML ML C

这是我的密码:

代码语言:javascript
复制
def apply_one_hot_encoding(X):
    dic = {}
    k = sorted(list(set(X)))
    for i in range(len(k)):
        arr = ['0' for i in range(len(k))]
        arr[i] = '1'
        dic[k[i]] = ''.join(arr)
    
    for i in range(len(X)):
        t = dic[X[i]]
        X[i] = t
         
    return X

if __name__ == "__main__":
    X = np.array(list(input().split()))
    
    one_hot_encoded_array = apply_one_hot_encoding(X)
    for i in one_hot_encoded_array:
        print(*i)

现在,我希望输出如下:

代码语言:javascript
复制
1 0 0 0 0 
0 0 0 1 0 
0 0 1 0 0 

但我得到的是:

代码语言:javascript
复制
1 0 0
0 0 1
1 0 0

如果我将t值附加到另一个列表并返回该列表,它将给出正确的结果。

为什么在直接替换的情况下,赋值被裁剪为仅3个字符?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-11-29 16:15:14

这个问题是由Numpy数组的dtype (数据类型)引起的。

当您使用print(X.dtype)在上面的程序中检查numy数组的数据类型时,它将数据类型显示为<U3,它只能容纳numpy数组X中每个元素的三个字符。

由于输入数组包含五个类别,因此可以通过dtype将数组的<U5转换为X = np.array(list(input().split()), dtype='<U5'),该X = np.array(list(input().split()), dtype='<U5')可以为numpy数组X中的每个元素保留最多5个字符。

修正的代码是,

代码语言:javascript
复制
def apply_one_hot_encoding(X):
    dic = {}
    k = sorted(list(set(X)))
    for i in range(len(k)):
        arr = ['0' for i in range(len(k))]
        arr[i] = '1'
        dic[k[i]] = ''.join(arr)
    
    for i in range(len(X)):
        t = dic[X[i]]
        X[i] = t
         
    return X

if __name__ == "__main__":
    X = np.array(list(input().split()),dtype = '<U5')
    
    one_hot_encoded_array = apply_one_hot_encoding(X)
    for i in one_hot_encoded_array:
        print(*i)

当将值存储在单独的numpy数组中时,不需要上述方法,因为numpy会根据字符串的大小自动更改数据类型,

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65060768

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档