Single-image inference in TensorFlow [Python]

Stack Overflow user
Asked 2017-08-16 01:04:19
Answers: 2, Views: 8.8K, Followers: 0, Votes: 5

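I have converted a pre-trained .ckpt file into a .pb file, freezing the model and saving the weights. What I am trying to do now is run a simple inference with that .pb file, then extract and save the output image. The model (a fully convolutional network for semantic segmentation) was downloaded from here: https://github.com/MarvinTeichmann/KittiSeg. So far I have managed to load the image, set the default tf graph, import the graph the model is defined on, read the input and output tensors, and run the session (this is where the error occurs).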

import tensorflow as tf
import os
import numpy as np
from tensorflow.python.platform import gfile
from PIL import Image

# Read the image & get statistics
img=Image.open('/path-to-image/demoImage.png')
img.show()
width, height = img.size
print(width)
print(height)

#Plot the image
#image.show()

with tf.Graph().as_default() as graph:

        with tf.Session() as sess:

                # Load the graph in graph_def
                print("load graph")

                # We load the protobuf file from the disk and parse it to retrieve the unserialized graph_def
                with gfile.FastGFile("/path-to-FCN-model/FCN8.pb",'rb') as f:

                                #Set default graph as current graph
                                graph_def = tf.GraphDef()
                                graph_def.ParseFromString(f.read())
                                #sess.graph.as_default() #new line

                                # Import a graph_def into the current default Graph
                                tf.import_graph_def(graph_def, name='')

                                # Print the name of operations in the session
                                #for op in sess.graph.get_operations():

                                    #print "Operation Name :",op.name            # Operation name
                                    #print "Tensor Stats :",str(op.values())     # Tensor name

                                # INFERENCE Here
                                l_input = graph.get_tensor_by_name('Placeholder:0')
                                l_output = graph.get_tensor_by_name('save/Assign_38:0')

                                print("l_input", l_input)
                                print("l_output", l_output)
                                print()
                                print()

                                # Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.
                                result = sess.run(l_output, feed_dict={l_input : img})
                                print(result)

                                print("Inference done")

                                # Info
                                # First Tensor name : Placeholder:0
                                # Last tensor name  : save/Assign_38:0"

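Does the error come from the format of the image (for example, should I convert the .png to another format)? Or is it some other, more fundamental mistake?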

2 Answers

Stack Overflow user

Accepted answer

Posted 2017-08-23 17:00:08

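I managed to fix the error. Below is the working script for running inference on a single image with the fully convolutional network (for anyone interested in a segmentation algorithm other than SegNet). This model uses bilinear interpolation for upscaling rather than un-pooling layers. Because the model is available for download in .chkpt format, you must first freeze it and save it as a .pb file. After that, you must pass the network through the TF optimizer to set the dropout probabilities to 1. Then set the correct input and output tensor names in this script, and inference works properly, extracting the segmented image.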

import tensorflow as tf # Default graph is initialized when the library is imported
import os
from tensorflow.python.platform import gfile
from PIL import Image
import numpy as np
import scipy
from scipy import misc
import matplotlib.pyplot as plt
import cv2

with tf.Graph().as_default() as graph: # Set default graph as graph

           with tf.Session() as sess:
                # Load the graph in graph_def
                print("load graph")

                # We load the protobuf file from the disk and parse it to retrieve the unserialized graph_def
                with gfile.FastGFile("/path-to-protobuf/FCN8_Freezed.pb",'rb') as f:

                                print("Load Image...")
                                # Read the image & get statistics
                                image = scipy.misc.imread('/Path-To-Image/uu_000010.png')
                                image = image.astype(float)
                                Input_image_shape=image.shape
                                height,width,channels = Input_image_shape

                                print("Plot image...")
                                #scipy.misc.imshow(image)

                                # Set FCN graph to the default graph
                                graph_def = tf.GraphDef()
                                graph_def.ParseFromString(f.read())
                                sess.graph.as_default()

                                # Import a graph_def into the current default Graph (In this case, the weights are (typically) embedded in the graph)

                                tf.import_graph_def(
                                graph_def,
                                input_map=None,
                                return_elements=None,
                                name="",
                                op_dict=None,
                                producer_op_list=None
                                )

                                # Print the name of operations in the session
                                for op in graph.get_operations():
                                        print("Operation Name :", op.name)          # Operation name
                                        print("Tensor Stats :", str(op.values()))   # Output tensors of the operation

                                # INFERENCE Here
                                l_input = graph.get_tensor_by_name('Inputs/fifo_queue_Dequeue:0') # Input Tensor
                                l_output = graph.get_tensor_by_name('upscore32/conv2d_transpose:0') # Output Tensor

                                # Print the static shape of the input tensor (tf.shape would only build a new op)
                                print("Shape of input : ", l_input.get_shape())
                                #initialize_all_variables
                                tf.global_variables_initializer()

                                # Run Kitty model on single image
                                Session_out = sess.run( l_output, feed_dict = {l_input : image} )
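
The script above ends at `sess.run`. As a rough sketch of the remaining step (turning the raw output into a segmentation image), the following assumes `Session_out` holds per-class scores with shape `(1, height, width, num_classes)`. The answer does not state the output shape, so that layout, the helper name `scores_to_mask`, and the choice of class 1 as foreground are all assumptions made for illustration.

```python
import numpy as np

def scores_to_mask(scores, foreground_class=1):
    """Convert per-class scores of shape (1, H, W, C) into an (H, W) uint8 mask.

    Hypothetical helper: the (1, H, W, C) layout is an assumption,
    not something the original script confirms.
    """
    scores = np.asarray(scores)
    if scores.ndim != 4 or scores.shape[0] != 1:
        raise ValueError("expected shape (1, H, W, C), got %r" % (scores.shape,))
    # Pick the highest-scoring class for every pixel.
    labels = np.argmax(scores[0], axis=-1)
    # 255 where the pixel belongs to the foreground class, 0 elsewhere.
    return np.where(labels == foreground_class, 255, 0).astype(np.uint8)

# Stand-in for Session_out: a 2x2 image with 2 classes.
fake_out = np.array([[[[0.9, 0.1], [0.2, 0.8]],
                      [[0.4, 0.6], [0.7, 0.3]]]])
mask = scores_to_mask(fake_out)
print(mask)
```

If the assumed layout holds, the resulting array could then be written to disk, for example with `Image.fromarray(mask).save(...)` from PIL, which the script already imports.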
Votes: 4

Stack Overflow user

Posted 2017-08-16 03:26:08

Have you looked at demo.py yet? Line 141 shows how they modify the input of the graph:

# Create placeholder for input
image_pl = tf.placeholder(tf.float32)
image = tf.expand_dims(image_pl, 0)

# build Tensorflow graph using the model from logdir
prediction = core.build_inference_graph(hypes, modules,
                                        image=image)

Line 164 shows how the image is opened:

image = scp.misc.imread(input_image)

The image is fed directly to image_pl. The only catch is that core.build_inference_graph is a TensorVision call.
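
For a frozen graph that lacks this placeholder wrapper, the same two steps (float conversion and adding a batch axis) can be done on the NumPy side before feeding. This is a minimal sketch, assuming the input tensor expects a float32 array of shape `(1, height, width, 3)`; whether a given graph needs the batch axis depends on how it was exported, and `prepare_image` is a hypothetical name.

```python
import numpy as np

def prepare_image(pixels):
    """Return a float32 array with a leading batch axis: (H, W, 3) -> (1, H, W, 3)."""
    arr = np.asarray(pixels, dtype=np.float32)
    if arr.ndim != 3 or arr.shape[2] != 3:
        raise ValueError("expected an (H, W, 3) image, got shape %r" % (arr.shape,))
    # NumPy counterpart of tf.expand_dims(image_pl, 0).
    return np.expand_dims(arr, 0)

# Stand-in for a decoded RGB image (4 rows, 6 columns).
dummy = np.zeros((4, 6, 3), dtype=np.uint8)
batch = prepare_image(dummy)
print(batch.shape, batch.dtype)
```

The array returned here is the kind of value that could be passed as the `feed_dict` entry for the input tensor.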

Note that it would also help to provide the exact error message, as well as the input.

Votes: 0
Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/45697823