首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow和cifar 10,测试单个图像

Tensorflow和cifar 10,测试单个图像
EN

Stack Overflow用户
提问于 2016-10-26 15:26:24
回答 1查看 786关注 0票数 4

我试着用tensorflow的cifar-10预测单个图像的类别。

我找到了这段代码,但是由于这个错误,它失败了:

赋值要求两个张量的形状相匹配。lhs shape= 18,384 rhs shape= 2304,384我知道这是因为批次的大小,只有1。(用expand_dims创建一个假批。)

但我不知道怎么解决这个问题?

我到处找但没有解决办法。提前感谢!

代码语言:javascript
复制
from PIL import Image
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
width = 24
height = 24

categories =  ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck" ]

filename = "path/to/jpg" # absolute path to input image
im = Image.open(filename)
im.save(filename, format='JPEG', subsampling=0, quality=100)
input_img = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
tf_cast = tf.cast(input_img, tf.float32)
float_image = tf.image.resize_image_with_crop_or_pad(tf_cast, height, width)
images = tf.expand_dims(float_image, 0)
logits = cifar10.inference(images)
_, top_k_pred = tf.nn.top_k(logits, k=5)
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('/tmp/cifar10_train')
if ckpt and ckpt.model_checkpoint_path:
    print("ckpt.model_checkpoint_path ", ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)
else:
    print('No checkpoint file found')
    exit(0)
sess.run(init_op)
_, top_indices = sess.run([_, top_k_pred])
for key, value in enumerate(top_indices[0]):
    print (categories[value] + ", " + str(_[0][key]))

编辑

我试着放置一个占位符,没有第一个形状,但是我得到了这个错误:必须完全定义一个新变量的形状(local3 3/weights),但是相反,是(?,384)。

现在我真的迷路了..。以下是新代码:

代码语言:javascript
复制
from PIL import Image
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
import itertools
width = 24
height = 24

categories = [ "airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck" ]

filename = "toto.jpg" # absolute path to input image
im = Image.open(filename)
im.save(filename, format='JPEG', subsampling=0, quality=100)
x = tf.placeholder(tf.float32, [None, 24, 24, 3])
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
 # Restore variables from training checkpoint.
    input_img = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
    tf_cast = tf.cast(input_img, tf.float32)
    float_image = tf.image.resize_image_with_crop_or_pad(tf_cast, height, width)
    images = tf.expand_dims(float_image, 0)
    i = images.eval()
    print (i)
    sess.run(init_op, feed_dict={x: i})
    logits = cifar10.inference(x)
    _, top_k_pred = tf.nn.top_k(logits, k=5)
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    ckpt = tf.train.get_checkpoint_state('/tmp/cifar10_train')
    if ckpt and ckpt.model_checkpoint_path:
        print("ckpt.model_checkpoint_path ", ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print('No checkpoint file found')
        exit(0)
    _, top_indices = sess.run([_, top_k_pred])
    for key, value in enumerate(top_indices[0]):
        print (categories[value] + ", " + str(_[0][key]))
EN

回答 1

Stack Overflow用户

发布于 2017-12-22 05:32:02

我认为这是因为由tf.Variabletf.get_variable获得的变量必须具有完全定义的形状。您可以检查您的代码并给出完整定义的形状。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/40266275

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档