首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Tensorflow对象检测API中使用两种模型

如何在Tensorflow对象检测API中使用两种模型
EN

Stack Overflow用户
提问于 2018-05-16 06:58:47
回答 2查看 2.9K关注 0票数 7

tensorflow对象检测API中,我们使用ssd_mobilenet_v1_coco_2017_11_17模型来检测90个通用对象。我想用这个模型来检测。接下来,我训练了faster_rcnn_inception_v2_coco_2018_01_28模型来检测自定义对象。我希望在相同的代码中使用这一点,在这段代码中,我将能够检测到这90个对象以及我的新培训的自定义对象。如何用单个代码来实现这一点?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-06-02 05:38:22

我是通过在detect_object.py中执行以下代码来实现这一点的

代码语言:javascript
复制
import numpy as np
import tensorflow as tf
import sys
from PIL import Image
import cv2

from utils import label_map_util
from utils import visualization_utils as vis_util

# ------------------ Knife Model Initialization ------------------------------ #
knife_label_map = label_map_util.load_labelmap('training/labelmap.pbtxt')
knife_categories = label_map_util.convert_label_map_to_categories(
    knife_label_map, max_num_classes=1, use_display_name=True)
knife_category_index = label_map_util.create_category_index(knife_categories)

knife_detection_graph = tf.Graph()

with knife_detection_graph.as_default():
    knife_od_graph_def = tf.GraphDef()
    with tf.gfile.GFile('inference_graph_3/frozen_inference_graph.pb', 'rb') as fid:
        knife_serialized_graph = fid.read()
        knife_od_graph_def.ParseFromString(knife_serialized_graph)
        tf.import_graph_def(knife_od_graph_def, name='')

    knife_session = tf.Session(graph=knife_detection_graph)

knife_image_tensor = knife_detection_graph.get_tensor_by_name('image_tensor:0')
knife_detection_boxes = knife_detection_graph.get_tensor_by_name(
    'detection_boxes:0')
knife_detection_scores = knife_detection_graph.get_tensor_by_name(
    'detection_scores:0')
knife_detection_classes = knife_detection_graph.get_tensor_by_name(
    'detection_classes:0')
knife_num_detections = knife_detection_graph.get_tensor_by_name(
    'num_detections:0')
# ---------------------------------------------------------------------------- #

# ------------------ General Model Initialization ---------------------------- #
general_label_map = label_map_util.load_labelmap('data/mscoco_label_map.pbtxt')
general_categories = label_map_util.convert_label_map_to_categories(
    general_label_map, max_num_classes=90, use_display_name=True)
general_category_index = label_map_util.create_category_index(
    general_categories)

general_detection_graph = tf.Graph()

with general_detection_graph.as_default():
    general_od_graph_def = tf.GraphDef()
    with tf.gfile.GFile('ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb', 'rb') as fid:
        general_serialized_graph = fid.read()
        general_od_graph_def.ParseFromString(general_serialized_graph)
        tf.import_graph_def(general_od_graph_def, name='')

    general_session = tf.Session(graph=general_detection_graph)

general_image_tensor = general_detection_graph.get_tensor_by_name(
    'image_tensor:0')
general_detection_boxes = general_detection_graph.get_tensor_by_name(
    'detection_boxes:0')
general_detection_scores = general_detection_graph.get_tensor_by_name(
    'detection_scores:0')
general_detection_classes = general_detection_graph.get_tensor_by_name(
    'detection_classes:0')
general_num_detections = general_detection_graph.get_tensor_by_name(
    'num_detections:0')
# ---------------------------------------------------------------------------- #


def knife(image_path):
    try:
        image = cv2.imread(image_path)
        image_expanded = np.expand_dims(image, axis=0)
        (boxes, scores, classes, num) = knife_session.run(
            [knife_detection_boxes, knife_detection_scores,
                knife_detection_classes, knife_num_detections],
            feed_dict={knife_image_tensor: image_expanded})

        classes = np.squeeze(classes).astype(np.int32)
        scores = np.squeeze(scores)
        boxes = np.squeeze(boxes)

        for c in range(0, len(classes)):
            class_name = knife_category_index[classes[c]]['name']
            if class_name == 'knife' and scores[c] > .80:
                confidence = scores[c] * 100
                break
            else:
                confidence = 0.00
    except:
        print("Error occurred in knife detection")
        confidence = 0.0   # Some error has occurred
    return confidence


def general(image_path):
    try:
        image = cv2.imread(image_path)
        image_expanded = np.expand_dims(image, axis=0)
        (boxes, scores, classes, num) = general_session.run(
            [general_detection_boxes, general_detection_scores,
                general_detection_classes, general_num_detections],
            feed_dict={general_image_tensor: image_expanded})

        classes = np.squeeze(classes).astype(np.int32)
        scores = np.squeeze(scores)
        boxes = np.squeeze(boxes)

        object_name = []
        object_score = []

        for c in range(0, len(classes)):
            class_name = general_category_index[classes[c]]['name']
            if scores[c] > .30:   # If confidence level is good enough
                object_name.append(class_name)
                object_score.append(str(scores[c] * 100)[:5])
    except:
        print("Error occurred in general detection")
        object_name = ['']
        object_score = ['']

    return object_name, object_score


if __name__ == '__main__':
    print(' in main')

我能做到

代码语言:javascript
复制
import detect_object
detect_object.knife("image.jpg")  # to detect whether knife is present in image(this is custom trained model)

detect_object.general("image.jpg")  # to detect those 90 objects from TF API

我知道在TF API中有刀模型,但它并不是那么精确,所以我重新训练它只是为了刀。最后,我有两个模型1.第一种模型是只检测刀子,第二种模型是像往常一样检测一般对象。

票数 3
EN

Stack Overflow用户

发布于 2018-05-31 20:55:14

你不能把这两种模式结合起来。有两段代码,一次加载一个模型,然后识别它能在图像中看到的内容。另一种选择是重新训练单个模型,以识别您感兴趣的所有对象。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50364281

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档