首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow Inception resnet v2输入张量

Tensorflow Inception resnet v2输入张量
EN

Stack Overflow用户
提问于 2016-09-28 22:35:16
回答 2查看 3.2K关注 0票数 0

我正在尝试运行这段代码

代码语言:javascript
复制
import os
import tensorflow as tf
from datasets import imagenet
from nets import inception_resnet_v2
from preprocessing import inception_preprocessing

checkpoints_dir = 'model'

slim = tf.contrib.slim

batch_size = 3
image_size = 299

with tf.Graph().as_default():

with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
    logits, _ = inception_resnet_v2.inception_resnet_v2([1, 299, 299, 3], num_classes=1001, is_training=False)
    probabilities = tf.nn.softmax(logits)

    init_fn = slim.assign_from_checkpoint_fn(
    os.path.join(checkpoints_dir, 'inception_resnet_v2_2016_08_30.ckpt'),
    slim.get_model_variables('InceptionResnetV2'))

    with tf.Session() as sess:
        init_fn(sess)

        imgPath = '.../image_3.jpeg'
        testImage_string = tf.gfile.FastGFile(imgPath, 'rb').read()
        testImage = tf.image.decode_jpeg(testImage_string, channels=3)

        np_image, probabilities = sess.run([testImage, probabilities])
        probabilities = probabilities[0, 0:]
        sorted_inds = [i[0] for i in sorted(enumerate(-probabilities), key=lambda x:x[1])]

        names = imagenet.create_readable_names_for_imagenet_labels()
        for i in range(15):
            index = sorted_inds[i]
            print((probabilities[index], names[index]))

但是TF显示一个错误:ValueError: rank of shape must be at least 4 not: 1

我认为问题出在输入张量形状[1, 299, 299, 3]中。3通道JPEG图像如何输入张量?

还有一个类似的问题(Using pre-trained inception_resnet_v2 with Tensorflow)。我在代码input_tensor中看到了--不幸的是,有关于什么是input_tensor的解释。也许我是在问一些不言而喻的问题,但我坚持了下来!在此之前,非常感谢您的建议!

EN

回答 2

Stack Overflow用户

发布于 2017-01-14 18:40:05

你必须对你的图像进行预处理。下面是一个代码:

代码语言:javascript
复制
import os
import tensorflow as tf
from datasets import imagenet
from nets import inception_resnet_v2
from preprocessing import inception_preprocessing

checkpoints_dir = 'model'

slim = tf.contrib.slim

batch_size = 3
image_size = 299

with tf.Graph().as_default():
    with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):

        imgPath = '.../cat.jpg'
        testImage_string = tf.gfile.FastGFile(imgPath, 'rb').read()
        testImage = tf.image.decode_jpeg(testImage_string, channels=3)
        processed_image = inception_preprocessing.preprocess_image(testImage, image_size, image_size, is_training=False)
        processed_images = tf.expand_dims(processed_image, 0)

        logits, _ = inception_resnet_v2.inception_resnet_v2(processed_images, num_classes=1001, is_training=False)
        probabilities = tf.nn.softmax(logits)

        init_fn = slim.assign_from_checkpoint_fn(
        os.path.join(checkpoints_dir, 'inception_resnet_v2_2016_08_30.ckpt'), slim.get_model_variables('InceptionResnetV2'))

        with tf.Session() as sess:
            init_fn(sess)

            np_image, probabilities = sess.run([processed_images, probabilities])
            probabilities = probabilities[0, 0:]
            sorted_inds = [i[0] for i in sorted(enumerate(-probabilities), key=lambda x: x[1])]

            names = imagenet.create_readable_names_for_imagenet_labels()
            for i in range(15):
                index = sorted_inds[i]
                print((probabilities[index], names[index]))

答案是:

代码语言:javascript
复制
(0.1131034, 'tiger cat')
(0.079478227, 'tabby, tabby cat')
(0.052777905, 'Cardigan, Cardigan Welsh corgi')
(0.030195976, 'laptop, laptop computer')
(0.027841948, 'bathtub, bathing tub, bath, tub')
(0.026694898, 'television, television system')
(0.024981709, 'carton')
(0.024039172, 'Egyptian cat')
(0.018425584, 'tub, vat')
(0.018221909, 'Pembroke, Pembroke Welsh corgi')
(0.015066789, 'skunk, polecat, wood pussy')
(0.01377619, 'screen, CRT screen')
(0.012509955, 'monitor')
(0.012224807, 'mouse, computer mouse')
(0.012188354, 'refrigerator, icebox')
票数 3
EN

Stack Overflow用户

发布于 2017-05-04 21:43:02

您可以使用tf.expand_dims(your_tensor_3channel, axis=0)将其扩展为批处理格式。

票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/39750572

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档