我有一个标量的NxTxW张量;我正在尝试找出如何使用numpy将其转换为one-hot编码。有什么建议吗?
发布于 2018-09-10 15:03:24
最简单的方法是使用np.eye()和numpy切片:
def one_hot_encode(y):
"""Do one-hot encoding of y
Parameters
----------
y : numpy array of arbitrary shape
Returns one-hot-encoded y of the same shape plus one-hot-encoded vector as
a last axis
"""
# map `y' to an index value (from 0 to number of classes minus one)
y_vals = sorted(np.unique(y))
K = len(y_vals)
to_index = np.vectorize(lambda x: y_vals.index(x))
y = to_index(y)
# remove the last dimension since we want to substitute it with a one-hot-vector
if y.shape[-1] == 1 and len(y.shape>1):
y = y.reshape(y.shape[:-1])
# do one hot encoding:
y = np.eye(K)[y].astype( np.uint8 if K<255 else np.uint16 )
return yhttps://stackoverflow.com/questions/52235135
复制相似问题