首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >categorical_accuracy和sparse_categorical_accuracy的不同之处

categorical_accuracy和sparse_categorical_accuracy的不同之处
EN

Stack Overflow用户
提问于 2017-06-10 19:55:10
回答 4查看 54.9K关注 0票数 69

在Keras中,categorical_accuracysparse_categorical_accuracy有什么区别?在这些指标的文档中没有任何提示,通过询问谷歌博士,我也没有找到答案。

源代码可以找到这里

代码语言:javascript
复制
def categorical_accuracy(y_true, y_pred):
    return K.cast(K.equal(K.argmax(y_true, axis=-1),
                          K.argmax(y_pred, axis=-1)),
                  K.floatx())


def sparse_categorical_accuracy(y_true, y_pred):
    return K.cast(K.equal(K.max(y_true, axis=-1),
                          K.cast(K.argmax(y_pred, axis=-1), K.floatx())),
                  K.floatx())
EN

回答 4

Stack Overflow用户

回答已采纳

发布于 2017-06-11 12:54:47

看着来源

代码语言:javascript
复制
def categorical_accuracy(y_true, y_pred):
    return K.cast(K.equal(K.argmax(y_true, axis=-1),
                          K.argmax(y_pred, axis=-1)),
                  K.floatx())


def sparse_categorical_accuracy(y_true, y_pred):
    return K.cast(K.equal(K.max(y_true, axis=-1),
                          K.cast(K.argmax(y_pred, axis=-1), K.floatx())),
K.floatx())

categorical_accuracy检查最大真值的指数是否等于最大预测值的指数。

sparse_categorical_accuracy检查最大真值是否等于最大预测值的指数。

根据Marcin上面的答案,categorical_accuracy对应于one-hot编码的y_true载体。

票数 53
EN

Stack Overflow用户

发布于 2017-06-11 12:42:45

因此,在categorical_accuracy中,您需要将目标(y)指定为一个热编码向量(例如,在3个类的情况下,当一个真正的类是第二类时,y应该是(0, 1, 0)。在sparse_categorical_accuracy中,您只需要提供真正类的整数(在前面的示例中--因为类索引是0-based),这将是1

票数 88
EN

Stack Overflow用户

发布于 2020-08-17 20:14:33

sparse_categorical_accuracy期望稀疏目标

代码语言:javascript
复制
[[0], [1], [2]]

例如:

代码语言:javascript
复制
import tensorflow as tf

sparse = [[0], [1], [2]]
logits = [[.8, .1, .1], [.5, .3, .2], [.2, .2, .6]]

sparse_cat_acc = tf.metrics.SparseCategoricalAccuracy()
sparse_cat_acc(sparse, logits)
代码语言:javascript
复制
<tf.Tensor: shape=(), dtype=float64, numpy=0.6666666666666666>

categorical_accuracy期望一个热编码目标

代码语言:javascript
复制
[[1., 0., 0.],  [0., 1., 0.], [0., 0., 1.]]

例如:

代码语言:javascript
复制
onehot = [[1., 0., 0.],  [0., 1., 0.], [0., 0., 1.]]
logits = [[.8, .1, .1], [.5, .3, .2], [.2, .2, .6]]

cat_acc = tf.metrics.CategoricalAccuracy()
cat_acc(sparse, logits)
代码语言:javascript
复制
<tf.Tensor: shape=(), dtype=float64, numpy=0.6666666666666666>
票数 17
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/44477489

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档