
Tensorflow: InvalidArgumentError with MNIST, [55000] vs. [10000]

Stack Overflow user
Asked 2017-01-13 10:38:57
Answers: 1 · Views: 938 · Follows: 0 · Votes: 0

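I'm working through the code from this presentation http://www.youtube.com/watch?v=vq2nnJ4g6N0&t=20m28s and get the following error: InvalidArgumentError (see above for traceback): Incompatible shapes: [55000] vs. [10000]

I've already fixed a few errors about tensor shapes/dimensions, but I don't understand what this one means specifically, let alone how to fix it.

I'm new to TF and any advice is much appreciated. Here is the code: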

# 1 ~ import tf + data
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

#2 ~ paras + init
X = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

init = tf.initialize_all_variables()

#3 ~ model + correct answers
Y = tf.nn.softmax(tf.matmul(X, W) + b)
Y_ = tf.placeholder(tf.float32, [None, 10]) # one-hot encoding

#4 ~ loss function
cross_entropy = -tf.reduce_sum(Y_ * tf.log(Y))

#5 ~ correct answer + % accuracy
is_correct = tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)) # one-hot decoding
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

#6 ~ optimizer and training step
optimizer = tf.train.GradientDescentOptimizer(0.003) # learning-rate
train_step = optimizer.minimize(cross_entropy)

#7 ~ session and training loop
sess = tf.Session()
sess.run(init)

for i in range(1000):
    # load a batch of images and correct answers
    batch_X, batch_Y = mnist.train.next_batch(100)
    train_data = {X: batch_X, Y_: batch_Y}

    # train
    sess.run(train_step, feed_dict=train_data)

    # success?
    a,c = sess.run([accuracy, cross_entropy], feed_dict=train_data)

    # success on test data?
    test_data = {X: mnist.train.images, Y_: mnist.test.labels}
    a,c = sess.run([accuracy, cross_entropy], feed_dict=test_data)

The full error output is:

Traceback (most recent call last):
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1021, in _do_call
    return fn(*args)
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1003, in _run_fn
    status, run_metadata)
  File "/Users/joelmcleod/anaconda3/lib/python3.5/contextlib.py", line 66, in __exit__
    next(self.gen)
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 469, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [55000] vs. [10000]
     [[Node: Equal = Equal[T=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"](ArgMax, ArgMax_1)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "deep_nn1.py", line 71, in <module>
    a,c = sess.run([accuracy, cross_entropy], feed_dict=test_data)
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 766, in run
    run_metadata_ptr)
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 964, in _run
    feed_dict_string, options, run_metadata)
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1014, in _do_run
    target_list, options, run_metadata)
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1034, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [55000] vs. [10000]
     [[Node: Equal = Equal[T=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"](ArgMax, ArgMax_1)]]

Caused by op 'Equal', defined at:
  File "deep_nn1.py", line 47, in <module>
    is_correct = tf.equal(tf.argmax(Y, 1), tf.argmax(Y_, 1)) # one-hot decoding
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/gen_math_ops.py", line 728, in equal
    result = _op_def_lib.apply_op("Equal", x=x, y=y, name=name)
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 759, in apply_op
    op_def=op_def)
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2240, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Users/joelmcleod/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1128, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Incompatible shapes: [55000] vs. [10000]
     [[Node: Equal = Equal[T=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"](ArgMax, ArgMax_1)]]

1 Answer

Stack Overflow user

Accepted answer

Posted 2017-01-17 08:54:27

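You are feeding the training images (55,000 rows) into X instead of the test images (10,000 rows), while Y_ receives the test labels. As a result, tf.argmax(Y, 1) returns 55,000 predicted classes and tf.argmax(Y_, 1) returns 10,000 true classes, and the Equal op cannot compare tensors of shapes [55000] and [10000]. The offending line is: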

test_data = {X: mnist.train.images, Y_: mnist.test.labels}
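To see why these row counts clash inside the Equal node, here is a minimal plain-Python model of the is_correct line, assuming lists of lists stand in for the tensors (this is an illustration, not TensorFlow's implementation). tf.argmax(..., 1) yields one class index per row, so its length equals the number of rows fed in; comparing two 1-D results of different lengths cannot work. The helper names argmax_rows and equal_elementwise are hypothetical, and the toy sizes 5 and 3 stand in for 55000 and 10000.

```python
def argmax_rows(matrix):
    """Index of the largest value in each row, mimicking tf.argmax(m, 1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]


def equal_elementwise(a, b):
    """Elementwise comparison of two 1-D sequences.

    Two 1-D shapes of different length cannot be broadcast together,
    so a length mismatch is reported as an error.
    """
    if len(a) != len(b):
        raise ValueError(f"Incompatible shapes: [{len(a)}] vs. [{len(b)}]")
    return [x == y for x, y in zip(a, b)]


# Toy stand-ins (3 classes): 5 predictions vs. 3 one-hot labels.
predictions = [[0.1, 0.7, 0.2]] * 5  # plays the role of Y computed from mnist.train.images
labels = [[0, 1, 0]] * 3             # plays the role of Y_ fed from mnist.test.labels

try:
    equal_elementwise(argmax_rows(predictions), argmax_rows(labels))
except ValueError as err:
    print(err)  # Incompatible shapes: [5] vs. [3]
```

Once both sides come from the same split, the two argmax results have the same length and the comparison succeeds.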

To fix it, feed the test images instead:

test_data = {X: mnist.test.images, Y_: mnist.test.labels}
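A cheap way to catch this kind of mix-up before calling sess.run is to check that every array in a feed dictionary has the same number of rows. The sketch below is a hypothetical helper (check_feed_rows is not part of TensorFlow) that works on anything supporting len(), such as the NumPy arrays in mnist.test.images and mnist.test.labels; plain lists are used here so it runs without TensorFlow, and only the row counts matter.

```python
def check_feed_rows(feed):
    """Return the shared row count of all arrays in `feed`, or raise if they differ."""
    if not feed:
        return 0
    row_counts = {name: len(values) for name, values in feed.items()}
    if len(set(row_counts.values())) > 1:
        raise ValueError(f"Row counts differ across feed: {row_counts}")
    return next(iter(row_counts.values()))


# Stand-ins for the question's arrays (shared row objects keep this cheap).
train_images = [[0.0] * 784] * 55000
test_images = [[0.0] * 784] * 10000
test_labels = [[0] * 10] * 10000

print(check_feed_rows({"X": test_images, "Y_": test_labels}))  # 10000

try:
    check_feed_rows({"X": train_images, "Y_": test_labels})
except ValueError as err:
    print(err)  # Row counts differ across feed: {'X': 55000, 'Y_': 10000}
```

In the question's loop, the same helper could be called as check_feed_rows(test_data) just before the test-set sess.run, so the mismatch surfaces with a readable message instead of a graph-level InvalidArgumentError.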
Votes: 1
Original page content provided by Stack Overflow. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/41626542
