首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow object_detection:无法找到输入和输出张量

Tensorflow object_detection:无法找到输入和输出张量
EN

Stack Overflow用户
提问于 2017-07-12 13:32:31
回答 2查看 6.7K关注 0票数 0

我使用tensorflow的对象检测API成功地训练并保存了一个更快的RCNN模型。现在,我正在尝试对代码进行一些推断,从本教程中获取一些代码。

但是,在成功地恢复了元计时器和检查点之后,系统无法找到输入和输出节点,我得到以下错误:

KeyError:“名称'image_tensor:0‘指的是不存在的张量。图中不存在’image_tensor:0_image_tensor:0‘的操作。”

检查点和元图是由train.py脚本在我自己的数据上创建的,遵循给出的这里指令。

这是我的密码:

代码语言:javascript
复制
OUTPUT_DIR = "my_path/models/SSD_v1/train"
CKPT_DIR = OUTPUT_DIR
LATEST_CKPT_FILENAME = "checkpoint"
LAST_CKPT_FILE = os.path.join(CKPT_DIR, LATEST_CKPT_FILENAME)
MODEL_FILENAME_PATH = os.path.join(OUTPUT_DIR, "model.ckpt.meta")
def load_image_into_numpy_array(image):
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)


def test_model(images_list, path_to_ckpt=None,
               meta_graph=None):
    if path_to_ckpt is None:
        path_to_ckpt = tf.train.latest_checkpoint(CKPT_DIR, LATEST_CKPT_FILENAME)
    if meta_graph is None:
        meta_graph = MODEL_FILENAME_PATH
    print("test_model launched")

    tf.reset_default_graph()
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            # Restore graph
            saver = tf.train.import_meta_graph(meta_graph, clear_devices=True)
            print('metagraph restored')
            saver.restore(sess, path_to_ckpt)
            print('graph restored')

            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')  # This is where the error happens
            # Each box represents a part of the image where a particular object was detected.
            detected_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
            # Each score represent how level of confidence for each of the objects.
            # Score is shown on the result image, together with the class label.
            detected_scores = detection_graph.get_tensor_by_name('detection_scores:0')
            detected_classes = detection_graph.get_tensor_by_name('detection_classes:0')
            num_detections = graph.get_tensor_by_name('num_detections:0')

            print("Output tensors: ")
            print(detected_boxes)
            print(detected_scores)
            print(detected_classes)
            print('')

            for i, image in enumerate(images_list):
                detected_boxes, detected_scores, detected_classes, num_detect = sess.run([detected_boxes, detected_scores, detected_classes, num_detections],
                         feed_dict={image_tensor: image})
                print(i, num_detect, detected_boxes, detected_scores, detected_classes)


def main():
    directory_path = "../data/samples/"
    image_files = [f for f in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, f))]
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_list = [ np.expand_dims(load_image_into_numpy_array(Image.open(os.path.join(directory_path, f))), axis=0) for f in image_files]
    test_model(images_list=image_list)

if __name__=="__main__":
    main()

全错误堆栈跟踪:

代码语言:javascript
复制
Traceback (most recent call last):   File "/home/guillaumedelaboulaye/PR8210PANO/faster-rcnn/pano_faster_rcnn/src/run_faster_rcnn_inference.py", line 99, in <module>
    main()   File "/home/guillaumedelaboulaye/PR8210PANO/faster-rcnn/pano_faster_rcnn/src/run_faster_rcnn_inference.py", line 95, in main
    test_model(images_list=image_list)   File "/home/guillaumedelaboulaye/PR8210PANO/faster-rcnn/pano_faster_rcnn/src/run_faster_rcnn_inference.py", line 48, in test_model
    image_tensor = graph.get_tensor_by_name('image_tensor:0')   File "/home/guillaumedelaboulaye/PR8210PANO/faster-rcnn/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2733, in get_tensor_by_name
    return self.as_graph_element(name, allow_tensor=True, allow_operation=False)   File "/home/guillaumedelaboulaye/PR8210PANO/faster-rcnn/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2584, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)   File "/home/guillaumedelaboulaye/PR8210PANO/faster-rcnn/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2626, in _as_graph_element_locked
    "graph." % (repr(name), repr(op_name))) KeyError: "The name 'image_tensor:0' refers to a Tensor which does not exist. The operation, 'image_tensor', does not exist in the graph."
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2017-07-12 16:52:08

在列车图中,输入/输出节点没有给出这些名称。您需要做的是通过graph.py工具“导出”您经过培训的模型。我相信它目前将它导出到一个冻结的图形或一个SavedModel,但在以后的版本中,它也将导出到普通的检查点。

票数 2
EN

Stack Overflow用户

发布于 2017-12-20 23:39:00

如果您希望在“将一个(冻结的) Tensorflow模型加载到内存”之后,使用示例代码来查找图形的节点名称,请参考object_detection_tutorial.ipynb。区块:

Od_graph_def.node中的节点:打印node.name

它应该列出您随后可以输入的所有节点名称。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45059162

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档