首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf.nn.sparse_softmax_cross_entropy_with_logits秩误差

tf.nn.sparse_softmax_cross_entropy_with_logits秩误差
EN

Stack Overflow用户
提问于 2017-09-27 11:29:59
回答 1查看 1.9K关注 0票数 0

这是我的代码:

代码语言:javascript
复制
import tensorflow as tf
    with tf.Session() as sess:
        y = tf.constant([0,0,1])
        x = tf.constant([0,1,0])
        r = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)
        sess.run()
        print(r.eval())

它会产生以下错误:

代码语言:javascript
复制
ValueError                                Traceback (most recent call last)
<ipython-input-10-28a8854a9457> in <module>()
      4     y = tf.constant([0,0,1])
      5     x = tf.constant([0,1,0])
----> 6     r = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)
      7     sess.run()
      8     print(r.eval())

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\nn_ops.py in sparse_softmax_cross_entropy_with_logits(_sentinel, labels, logits, name)
   1687       raise ValueError("Rank mismatch: Rank of labels (received %s) should "
   1688                        "equal rank of logits minus 1 (received %s)." %
-> 1689                        (labels_static_shape.ndims, logits.get_shape().ndims))
   1690     # Check if no reshapes are required.
   1691     if logits.get_shape().ndims == 2:

ValueError: Rank mismatch: Rank of labels (received 1) should equal rank of logits minus 1 (received 1).

有人能帮我理解这个错误吗?如何手工计算软件最大值和交叉熵是非常直接的。

另外,我将如何使用这个函数,我需要将批处理输入到它(2个昏暗数组)?

更新

我也试过:

代码语言:javascript
复制
import tensorflow as tf

with tf.Session() as sess:
    y = tf.constant([1])
    x = tf.constant([0,1,0])
    r = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)
    sess.run()
    print(r.eval())

它产生了同样的错误

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-09-27 16:56:50

帮你修好了。x需要是一个二维向量

代码语言:javascript
复制
with tf.Session() as sess:
    y = tf.constant([1])
    x = tf.expand_dims(tf.constant([0.0, 1.0, 0.0]), 0)
    r = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)
    print(r.eval())
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46446763

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档