
Using a TensorFlow pre-trained model

Stack Overflow user
Asked 2020-03-09 09:10:04
Answers: 1 · Views: 381 · Followers: 0 · Votes: 0

I am trying to use TensorFlow's pre-trained models with their weights.

I'm a bit lost as to how I should load one to generate predictions. I want to use a faster_rcnn model to do object detection on images.

For the model faster_rcnn_inception_resnet_v2_atrous_lowproposals_oid_2018_01_28, I have the following files:

|   checkpoint
|   frozen_inference_graph.pb
|   model.ckpt.data-00000-of-00001
|   model.ckpt.index
|   model.ckpt.meta
|   pipeline.config
|
\---saved_model
    |   saved_model.pb
    |
    \---variables

Here is my attempt at loading the model and generating some predictions:

import tensorflow as tf  # TensorFlow 1.x API
import cv2

model_folder = "faster_rcnn_inception_resnet_v2_atrous_lowproposals_oid_2018_01_28"

model_graph_file = model_folder + "/frozen_inference_graph.pb"
model_weights_file = model_folder + "/model.ckpt.data-00000-of-00001"  # not used below

# Parse the serialized GraphDef from disk
graph_def = tf.GraphDef()
with tf.gfile.GFile(model_graph_file, 'rb') as f:
    graph_def.ParseFromString(f.read())

# Import the GraphDef into a Graph object so tensors can be looked up by name
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

#print([n.name + '=>' + n.op for n in graph_def.node if n.op == 'Placeholder'])
#print([n.name + '=>' + n.op for n in graph_def.node if n.op == 'Softmax'])

image_tensor = graph.get_tensor_by_name('image_tensor:0')
classes = graph.get_tensor_by_name('detection_classes:0')
scores = graph.get_tensor_by_name('detection_scores:0')
boxes = graph.get_tensor_by_name('detection_boxes:0')
softmax = graph.get_tensor_by_name('Softmax:0')

my_image = cv2.imread('resources/my_image.jpg')

with tf.Session(graph=graph) as sess:
    classes_out, scores_out, boxes_out, softmax_out = sess.run(
        [classes, scores, boxes, softmax], feed_dict={image_tensor: [my_image]})
    print(classes_out)
    print(classes_out.shape)
    print(scores_out)
    print(scores_out.shape)
    print(boxes_out)
    print(boxes_out.shape)
    print(softmax_out)
    print(softmax_out.shape)

which prints:

[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
(1, 20)
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
(1, 20)
[[[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]]
(1, 20, 4)
[[9.9819970e-01 1.8002436e-03]
 [9.9932957e-01 6.7051285e-04]
 [9.9853170e-01 1.4682930e-03]
 ...
 [9.9990737e-01 9.2630769e-05]
 [9.9939859e-01 6.0135941e-04]
 [9.6443009e-01 3.5569914e-02]]
(115200, 2)

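Clearly I'm doing something wrong here, but I can't tell what. How do I know which layers to use as output layers? How do I retrieve the classes, scores, and boxes of the detected objects? Is my model correct?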

Edit:

Following Lescurel's answer:

For some reason, I had to make a few changes to the code to get it to run: tf.saved_model.tag_constants.SERVING -> [tf.saved_model.tag_constants.SERVING]

input_tensor = model_signature["inputs"].name -> input_tensor = model_signature.inputs['inputs'].name (using TensorFlow 1.12).
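Both changes have a concrete cause. `tf.saved_model.loader.load` takes `tags` as a collection of strings and, as far as I can tell from the TF 1.x loader source, selects the MetaGraphDef by comparing `set(tags)` with the stored tags; `set()` of a bare string yields its characters, so no MetaGraphDef matches. Separately, a `SignatureDef` is a protocol-buffer message rather than a dict, so its `inputs` map is reached as an attribute (`model_signature.inputs['inputs']`), not with `model_signature['inputs']`. The TensorFlow-free sketch below only mimics the set comparison; `tags_match` is a hypothetical stand-in, and `"serve"` is the value of `tf.saved_model.tag_constants.SERVING`.

```python
# Value of tf.saved_model.tag_constants.SERVING in TF 1.x.
SERVING = "serve"


def tags_match(stored_tags, requested_tags) -> bool:
    """Hypothetical stand-in for the loader's check: compare tag collections as sets."""
    return set(stored_tags) == set(requested_tags)


stored = ["serve"]  # tags assumed to be stored in the zoo's saved_model.pb
print(set(SERVING))                   # the individual characters of "serve"
print(tags_match(stored, SERVING))    # False: a bare string is split into characters
print(tags_match(stored, [SERVING]))  # True: a one-element list keeps the whole tag
```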

I now get some results, which I'm very happy about, but with the same image and model that Lescurel used, my output is very different:

[array([[0.5936514 , 0.5774365 , 0.519677  , 0.46745843, 0.36366013,
        0.3496253 , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ]],
      dtype=float32), array([[33.,  1., 68., 11., 13.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.]], dtype=float32), array([6.], dtype=float32), array([[[0.6699049 , 0.68924683, 0.9372702 , 0.78685343],
        [0.21414267, 0.264757  , 0.9868771 , 0.51174635],
        [0.34444967, 0.65146637, 0.70101655, 0.80124986],
        [0.8743748 , 0.7071637 , 0.9687472 , 0.7784833 ],
        [0.7832241 , 0.51456743, 0.9550611 , 0.59617543],
        [0.32543942, 0.6407225 , 0.9539846 , 0.81454873],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]]], dtype=float32)]

Any idea why?
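One thing worth ruling out (an assumption, not confirmed anywhere in this thread): `cv2.imread` returns pixels in BGR order, while the Object Detection API's own examples load images as RGB (via PIL), so feeding the OpenCV array directly may shift the predictions. The sketch below converts an OpenCV image into an RGB uint8 batch of shape (1, H, W, 3), which is the layout `image_tensor` is fed with in the code above; `to_model_input` is a hypothetical helper name.

```python
import numpy as np


def to_model_input(bgr_image: np.ndarray) -> np.ndarray:
    """Turn an OpenCV-style BGR image (H, W, 3) into an RGB uint8 batch (1, H, W, 3)."""
    if bgr_image is None:
        # cv2.imread returns None instead of raising when the file cannot be read.
        raise ValueError("image is None; check the path passed to cv2.imread")
    if bgr_image.ndim != 3 or bgr_image.shape[2] != 3:
        raise ValueError(f"expected an (H, W, 3) image, got shape {bgr_image.shape}")
    # Reverse the last axis: BGR -> RGB (same effect as cv2.cvtColor with COLOR_BGR2RGB).
    rgb = bgr_image[..., ::-1]
    # Add the leading batch axis and return a contiguous uint8 copy.
    return np.ascontiguousarray(rgb[np.newaxis], dtype=np.uint8)
```

In either snippet, the feed would then become `feed_dict={input_tensor: to_model_input(img)}` instead of wrapping the raw OpenCV image in a list.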


1 answer

Stack Overflow user

Accepted answer

Posted 2020-03-09 10:12:35

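You loaded the graph structure of the network, but not its trained weights. Because of this, the network cannot make any meaningful predictions. To load a graph's weights in TF 1.x, you can refer to the guide.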

The snippet below loads the graph and its weights and runs a prediction (it uses faster_rcnn_inception_resnet_v2_atrous_lowproposals_coco from the model zoo):

import cv2
import tensorflow as tf  # tf 1.x

model_dir = "faster_rcnn_inception_resnet_v2_atrous_lowproposals_coco_2018_01_28/saved_model"

img = cv2.imread("/path/to/image.jpg")

with tf.Session() as sess:
    # We load the model and its weights
    # Models from the zoo are frozen, so we use the SERVING tag
    # (tags must be passed as a list of strings)
    model = tf.saved_model.loader.load(sess,
                                       [tf.saved_model.tag_constants.SERVING],
                                       model_dir)
    # We get the model signature
    model_signature = model.signature_def["serving_default"]
    # SignatureDef is a protobuf message: access its `inputs` map as an attribute
    input_tensor = model_signature.inputs["inputs"].name
    # Getting the names of the outputs
    output_tensor = [v.name for k, v in model_signature.outputs.items() if v.name]
    # Running the prediction
    outs = sess.run(output_tensor, feed_dict={input_tensor: [img]})
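The list comprehension above returns the outputs in the iteration order of the signature's `outputs` map, which protobuf leaves unspecified, so the position of scores, classes and boxes in `outs` is not self-describing. A small sketch, assuming you would rather index results by signature key: `run_by_key` is a hypothetical helper that takes any function mapping a list of tensor names to their values (for example a lambda around `sess.run`).

```python
from typing import Callable, Dict, List, Mapping, Sequence


def run_by_key(run_fn: Callable[[List[str]], Sequence],
               output_names: Mapping[str, str]) -> Dict[str, object]:
    """Fetch every signature output and return the results keyed by output key.

    run_fn: takes a list of tensor names, returns their values in the same order.
    output_names: signature output key -> tensor name.
    """
    keys = sorted(output_names)  # fix an explicit, reproducible order
    values = list(run_fn([output_names[k] for k in keys]))
    if len(values) != len(keys):
        raise ValueError(f"expected {len(keys)} results, got {len(values)}")
    return dict(zip(keys, values))


# Hypothetical usage inside the `with tf.Session() as sess:` block above:
# output_names = {k: v.name for k, v in model_signature.outputs.items()}
# results = run_by_key(
#     lambda names: sess.run(names, feed_dict={input_tensor: [img]}),
#     output_names)
# boxes = results["detection_boxes"]  # key name assumed from the zoo's export
```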

Sample output on an image:

>>> outs
[array([[0.9998708 , 0.99963164, 0.9926651 , 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        , 0.        , 0.        , 0.        ]],
       dtype=float32),
 array([[ 1.,  1., 18.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.]], dtype=float32),
 array([3.], dtype=float32),
 array([[[0.35335696, 0.6397857 , 0.96252066, 0.8067749 ],
         [0.25126144, 0.2766906 , 0.97366196, 0.5463176 ],
         [0.7696026 , 0.52089834, 0.9537483 , 0.59052485],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ]]], dtype=float32)]
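To turn these arrays into usable detections: the outputs are padded to a fixed length (20 here), `num_detections` says how many rows are real (3 in this sample), and the boxes are normalized `[ymin, xmin, ymax, xmax]` coordinates. Class IDs index into the label map the model was trained with; for the COCO models that is `mscoco_label_map.pbtxt` in the Object Detection API's `object_detection/data` directory, where IDs 1 and 18 are "person" and "dog". An `_oid_` model indexes a different (Open Images) label map, so its class IDs cannot be compared with a `_coco_` model's. The sketch below is a minimal regex-based parser for simple label maps (an assumption: one `id` and one `name`/`display_name` per `item`, no nested braces) plus a filter that scales boxes to pixels; both function names are hypothetical.

```python
import re
from typing import Dict, List, Optional, Tuple

_ITEM_RE = re.compile(r"item\s*\{(.*?)\}", re.DOTALL)
_ID_RE = re.compile(r"\bid\s*:\s*(\d+)")
_DISPLAY_RE = re.compile(r'\bdisplay_name\s*:\s*"([^"]*)"')
_NAME_RE = re.compile(r'\bname\s*:\s*"([^"]*)"')

Detection = Tuple[str, float, Tuple[int, int, int, int]]


def parse_label_map(pbtxt: str) -> Dict[int, str]:
    """Map class id -> display_name (falling back to name) from a .pbtxt string."""
    labels = {}
    for body in _ITEM_RE.findall(pbtxt):
        id_match = _ID_RE.search(body)
        if id_match is None:
            continue
        name_match = _DISPLAY_RE.search(body) or _NAME_RE.search(body)
        class_id = int(id_match.group(1))
        labels[class_id] = name_match.group(1) if name_match else str(class_id)
    return labels


def extract_detections(scores, classes, num_detections, boxes,
                       image_height: int, image_width: int,
                       labels: Optional[Dict[int, str]] = None,
                       min_score: float = 0.5) -> List[Detection]:
    """Return (label, score, (top, left, bottom, right)) for batch element 0."""
    labels = labels or {}
    results = []
    # Only the first num_detections rows are real; the rest is zero padding.
    for i in range(int(num_detections[0])):
        score = float(scores[0][i])
        if score < min_score:
            continue
        class_id = int(classes[0][i])
        # Normalized [ymin, xmin, ymax, xmax] -> pixel coordinates.
        ymin, xmin, ymax, xmax = (float(v) for v in boxes[0][i])
        pixel_box = (round(ymin * image_height), round(xmin * image_width),
                     round(ymax * image_height), round(xmax * image_width))
        results.append((labels.get(class_id, str(class_id)), score, pixel_box))
    return results
```

If `outs` happens to be in the printed order (scores, classes, num_detections, boxes), a call could look like `extract_detections(*outs, img.shape[0], img.shape[1], labels=parse_label_map(label_map_text))`, where `label_map_text` is the contents of the label-map file.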
Votes: 1
Page content originally from Stack Overflow; translation provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/60597454
