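I converted a pretrained .ckpt file into a .pb file, freezing the model and saving the weights. What I am trying to do now is run a simple inference with this .pb file, then extract and save the output image. The model was downloaded from here (Fully Convolutional Networks for semantic segmentation): https://github.com/MarvinTeichmann/KittiSeg. So far I have managed to load the image, set the default tf graph, import the graph defined by the model into it, read the input and output tensors, and run the session (this is where the error occurs).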
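Extracting a segmentation image from a network like this usually means turning per-pixel class scores into a mask. A minimal NumPy sketch, assuming the output tensor holds raw logits of shape (1, H, W, 2) with channel 1 as the road class (KittiSeg is a two-class road/non-road model, but the channel order here is an assumption); the function name `logits_to_road_mask` and the 0.5 threshold are illustrative choices, not taken from the question:

```python
import numpy as np


def logits_to_road_mask(logits, threshold=0.5):
    """Convert per-pixel logits to a binary uint8 mask (255 = road, 0 = other).

    Assumes `logits` has shape (1, H, W, 2) or (H, W, 2) and that channel 1
    is the road class (an assumption; check against your model).
    """
    logits = np.asarray(logits, dtype=np.float64)
    if logits.ndim == 4:
        logits = logits[0]  # drop the batch dimension
    # Numerically stable softmax over the class axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=-1, keepdims=True)
    road_prob = probs[..., 1]
    # Pixels whose road probability exceeds the threshold become white
    return np.where(road_prob > threshold, 255, 0).astype(np.uint8)


if __name__ == "__main__":
    demo = np.zeros((1, 2, 3, 2))
    demo[0, 0, 0, 1] = 5.0  # strongly "road" at pixel (0, 0)
    print(logits_to_road_mask(demo))
```

The resulting 2-D uint8 array can be written to disk with PIL, e.g. `Image.fromarray(mask).save("out.png")`.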
import tensorflow as tf
import os
import numpy as np
from tensorflow.python.platform import gfile
from PIL import Image

# Read the image & get statistics
img = Image.open('/path-to-image/demoImage.png')
img.show()
width, height = img.size
print(width)
print(height)
# Plot the image
# image.show()

with tf.Graph().as_default() as graph:
    with tf.Session() as sess:
        # Load the graph in graph_def
        print("load graph")
        # We load the protobuf file from the disk and parse it to retrieve the unserialized graph_def
        with gfile.FastGFile("/path-to-FCN-model/FCN8.pb", 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # sess.graph.as_default() #new line
        # Import a graph_def into the current default Graph
        tf.import_graph_def(graph_def, name='')
        # Print the name of operations in the session
        # for op in sess.graph.get_operations():
        #     print("Operation Name :", op.name)  # Operation name
        #     print("Tensor Stats :", str(op.values()))  # Tensor name
        # INFERENCE Here
        l_input = graph.get_tensor_by_name('Placeholder:0')
        l_output = graph.get_tensor_by_name('save/Assign_38:0')
        print("l_input", l_input)
        print("l_output", l_output)
        print()
        print()
        # Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.
        result = sess.run(l_output, feed_dict={l_input: img})
        print(result)
        print("Inference done")

# Info
# First Tensor name : Placeholder:0
# Last tensor name : save/Assign_38:0
Could the error come from the image format (for example, should I convert the .png to another format)? Or is it a more fundamental error?
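On the format question: the code comment above lists the feed types TensorFlow documents (scalars, strings, lists, NumPy arrays, TensorHandles), and a `PIL.Image` object is not among them, so converting it explicitly is a sensible step regardless of the file format. A minimal sketch, assuming the model expects a single H×W×3 float32 image; the helper name `to_feed_array` is hypothetical:

```python
import numpy as np


def to_feed_array(img):
    """Convert an image-like object (e.g. a PIL.Image or ndarray) to H x W x 3 float32.

    Assumes the network wants 3 channels: grayscale input is replicated and an
    alpha channel, if present, is dropped. Palette ("P" mode) PIL images should
    be converted with img.convert("RGB") first, since their pixels are indices.
    """
    arr = np.asarray(img)  # PIL images expose the NumPy array interface
    if arr.ndim == 2:  # grayscale -> 3 identical channels
        arr = np.stack([arr, arr, arr], axis=-1)
    elif arr.ndim == 3 and arr.shape[-1] == 4:  # RGBA -> RGB
        arr = arr[..., :3]
    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError("expected an H x W x 3 image, got shape %s" % (arr.shape,))
    return arr.astype(np.float32)


if __name__ == "__main__":
    rgba = np.zeros((4, 5, 4), dtype=np.uint8)
    out = to_feed_array(rgba)
    print(out.shape, out.dtype)
```

With the question's variables, the feed would become `feed_dict={l_input: to_feed_array(img)}`.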
Posted on 2017-08-23 17:00:08
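I managed to fix the error. Below is a working script for running inference on a single image with the Fully Convolutional Network (for anyone interested in an alternative to SEGNET for segmentation). This model upscales with bilinear interpolation instead of unpooling layers. Since the model is downloadable in .ckpt format, you first have to freeze it and save it as a .pb file. After that, you have to pass the network through the TF optimizer to set the Dropout probability to 1. Finally, set the correct input and output tensor names in the script, and inference works correctly, producing the segmented image.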
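In FCN-style models, bilinear upscaling is typically implemented as a transposed convolution whose filter is initialized with bilinear interpolation weights. A minimal NumPy sketch of such a kernel, assuming an upsampling factor f and the common kernel size 2*f - f % 2; this illustrates the general FCN technique and is not copied from KittiSeg's code (the function name `bilinear_kernel` is hypothetical):

```python
import numpy as np


def bilinear_kernel(size):
    """Return a (size, size) bilinear interpolation kernel.

    This is the classic FCN-style initializer for a transposed convolution
    that upsamples by factor = (size + 1) // 2.
    """
    factor = (size + 1) // 2
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    rows, cols = np.ogrid[:size, :size]
    # Weights fall off linearly with distance from the center along each axis
    return (1 - abs(rows - center) / factor) * (1 - abs(cols - center) / factor)


if __name__ == "__main__":
    # Upsampling factor 2 -> kernel size 2*2 - 2 % 2 = 4
    print(bilinear_kernel(4))
```

In a TF 1.x graph, a kernel like this would be placed on the diagonal of a [size, size, num_classes, num_classes] filter (weights[:, :, i, i] = kernel) and passed to `tf.nn.conv2d_transpose` with a stride equal to the upsampling factor.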
import tensorflow as tf  # Default graph is initialized when the library is imported
import numpy as np
from tensorflow.python.platform import gfile
from PIL import Image

with tf.Graph().as_default() as graph:  # Set default graph as graph
    with tf.Session() as sess:
        # Load the graph in graph_def
        print("load graph")
        # We load the protobuf file from the disk and parse it to retrieve the unserialized graph_def
        with gfile.FastGFile("/path-to-protobuf/FCN8_Freezed.pb", 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        print("Load Image...")
        # Read the image & get statistics
        # (scipy.misc.imread was removed in SciPy 1.2, so PIL is used instead)
        image = np.asarray(Image.open('/Path-To-Image/uu_000010.png').convert('RGB'), dtype=np.float32)
        height, width, channels = image.shape

        # Import the graph_def into the current default graph
        # (in a frozen graph the weights are embedded as constants)
        tf.import_graph_def(graph_def, name="")

        # Print the name of operations in the graph
        for op in graph.get_operations():
            print("Operation Name :", op.name)  # Operation name
            print("Tensor Stats :", str(op.values()))  # Tensor name

        # INFERENCE Here
        l_input = graph.get_tensor_by_name('Inputs/fifo_queue_Dequeue:0')  # Input Tensor
        l_output = graph.get_tensor_by_name('upscore32/conv2d_transpose:0')  # Output Tensor
        print("Shape of input : ", l_input.shape)
        # No variable initializer is needed: a frozen graph holds no variables

        # Run the KittiSeg model on a single image
        Session_out = sess.run(l_output, feed_dict={l_input: image})
Posted on 2017-08-16 03:26:08
Have you looked at demo.py? Line 141 shows how they modify the graph's input:
# Create placeholder for input
image_pl = tf.placeholder(tf.float32)
image = tf.expand_dims(image_pl, 0)
# build Tensorflow graph using the model from logdir
prediction = core.build_inference_graph(hypes, modules,
                                        image=image)
And line 164 shows how the image is opened:
image = scp.misc.imread(input_image)
This array is fed directly to image_pl. The only catch is that core.build_inference_graph is a TensorVision call.
Note that it would also help to provide the exact error message as well as the input.
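The `tf.expand_dims(image_pl, 0)` line in demo.py is what lets an unbatched H×W×3 array be fed directly: the leading batch axis is added inside the graph. If the tensor you feed instead expects an already-batched input, the same reshaping has to happen on the NumPy side. A minimal sketch, assuming a batch of one; the function name `add_batch_axis` is hypothetical:

```python
import numpy as np


def add_batch_axis(image):
    """NumPy counterpart of tf.expand_dims(image, 0): (H, W, C) -> (1, H, W, C)."""
    image = np.asarray(image, dtype=np.float32)
    if image.ndim == 4:
        return image  # already batched, leave unchanged
    return np.expand_dims(image, 0)


if __name__ == "__main__":
    print(add_batch_axis(np.zeros((4, 6, 3))).shape)  # (1, 4, 6, 3)
```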
https://stackoverflow.com/questions/45697823