首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >向one-hot编码的张量转换

向one-hot编码的张量转换
EN

Stack Overflow用户
提问于 2018-09-08 20:02:29
回答 1查看 42关注 0票数 0

我有一个标量的NxTxW张量;我正在尝试找出如何使用numpy将其转换为one-hot编码。有什么建议吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-09-10 15:03:24

最简单的方法是使用np.eye()和numpy切片:

代码语言:javascript
复制
def one_hot_encode(y):
    """Do one-hot encoding of y
    Parameters
    ----------

    y : numpy array of arbitrary shape

    Returns one-hot-encoded y of the same shape plus one-hot-encoded vector as
    a last axis
    """

    # map `y' to an index value (from 0 to number of classes minus one)
    y_vals = sorted(np.unique(y))
    K = len(y_vals)
    to_index = np.vectorize(lambda x: y_vals.index(x))
    y = to_index(y)

    # remove the last dimension since we want to substitute it with a one-hot-vector
    if y.shape[-1] == 1 and len(y.shape>1):
        y = y.reshape(y.shape[:-1])

    # do one hot encoding:
    y = np.eye(K)[y].astype( np.uint8 if K<255 else np.uint16 )
    return y
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52235135

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档