
TensorFlowSharp and RetinaNet: how do I determine what to fetch when running the graph?

Stack Overflow user
Asked on 2018-10-22 23:29:53
1 answer · 442 views · 0 followers · Score 0

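I have been using TensorFlowSharp successfully with Faster R-CNN for a while. However, I recently trained a RetinaNet model, verified that it works in Python, and created a frozen .pb file for use with TensorFlow. For FRCNN, the TensorFlowSharp GitHub repo has an example showing how to run the model and fetch its outputs. For RetinaNet, I tried adapting that code, but nothing seems to work. I have the RetinaNet model summary and have tried working from it, but it is still not clear to me which names I should use.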

For FRCNN, the graph is run like this:

    var runner = m_session.GetRunner();

    // feed the image into "image_tensor" and fetch output 0 of each detection op
    runner
        .AddInput(m_graph["image_tensor"][0], tensor)
        .Fetch(
            m_graph["detection_boxes"][0],
            m_graph["detection_scores"][0],
            m_graph["detection_classes"][0],
            m_graph["num_detections"][0]);

    var output = runner.Run();

    // results come back in the same order as the Fetch() arguments
    var boxes = (float[,,])output[0].GetValue(jagged: false);
    var scores = (float[,])output[1].GetValue(jagged: false);
    var classes = (float[,])output[2].GetValue(jagged: false);
    var num = (float[])output[3].GetValue(jagged: false);

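From the FRCNN model summary, it is clear what the input ("image_tensor") and the outputs ("detection_boxes", "detection_scores", "detection_classes" and "num_detections") are. For RetinaNet they are different (I tried those names), and I don't know what they should be. The "Fetch" part of the code above crashes, and I'm guessing that is because I'm not using the correct node names.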
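One way to test that guess before calling `Run()` is to compare the requested names against the operation names that actually exist in the graph. In TensorFlowSharp, `m_graph["detection_boxes"][0]` refers to output 0 of the op `detection_boxes`, i.e. the tensor conventionally written `detection_boxes:0`; if the op is not in the graph, the lookup cannot succeed, which would be consistent with the crash. The sketch below is a hedged illustration in Python, not a TensorFlowSharp API: `split_tensor_name` and `missing_fetches` are hypothetical helpers, and the op list is assumed to be a list of op-name strings dumped from the .pb file.

```python
def split_tensor_name(name):
    """Split 'op_name:index' into (op_name, index); a bare op name means index 0."""
    op, sep, idx = name.rpartition(":")
    if sep and idx.isdigit():
        return op, int(idx)
    return name, 0


def missing_fetches(requested, op_names):
    """Return the requested tensor names whose op does not exist in op_names."""
    available = set(op_names)
    return [n for n in requested if split_tensor_name(n)[0] not in available]


# Hypothetical op list for a RetinaNet graph loaded without a name prefix
ops = ["input_1", "filtered_detections/map/TensorArrayStack/TensorArrayGatherV3"]
wanted = ["input_1:0", "detection_boxes:0", "detection_scores:0"]
print(missing_fetches(wanted, ops))  # the FRCNN names are not present here
```

Any name this reports is a name the C# `Fetch` call would also fail to resolve.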

I won't paste the entire RetinaNet summary here, but here are the first few nodes:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, None, None, 3 0                                            
__________________________________________________________________________________________________
padding_conv1 (ZeroPadding2D)   (None, None, None, 3 0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, None, None, 6 9408        padding_conv1[0][0]              
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, None, None, 6 256         conv1[0][0]                      
__________________________________________________________________________________________________
conv1_relu (Activation)         (None, None, None, 6 0           bn_conv1[0][0]                   
__________________________________________________________________________________________________

And here are the last few nodes:

__________________________________________________________________________________________________
anchors_0 (Anchors)             (None, None, 4)      0           P3[0][0]                         
__________________________________________________________________________________________________
anchors_1 (Anchors)             (None, None, 4)      0           P4[0][0]                         
__________________________________________________________________________________________________
anchors_2 (Anchors)             (None, None, 4)      0           P5[0][0]                         
__________________________________________________________________________________________________
anchors_3 (Anchors)             (None, None, 4)      0           P6[0][0]                         
__________________________________________________________________________________________________
anchors_4 (Anchors)             (None, None, 4)      0           P7[0][0]                         
__________________________________________________________________________________________________
regression_submodel (Model)     (None, None, 4)      2443300     P3[0][0]                         
                                                                 P4[0][0]                         
                                                                 P5[0][0]                         
                                                                 P6[0][0]                         
                                                                 P7[0][0]                         
__________________________________________________________________________________________________
anchors (Concatenate)           (None, None, 4)      0           anchors_0[0][0]                  
                                                                 anchors_1[0][0]                  
                                                                 anchors_2[0][0]                  
                                                                 anchors_3[0][0]                  
                                                                 anchors_4[0][0]                  
__________________________________________________________________________________________________
regression (Concatenate)        (None, None, 4)      0           regression_submodel[1][0]        
                                                                 regression_submodel[2][0]        
                                                                 regression_submodel[3][0]        
                                                                 regression_submodel[4][0]        
                                                                 regression_submodel[5][0]        
__________________________________________________________________________________________________
boxes (RegressBoxes)            (None, None, 4)      0           anchors[0][0]                    
                                                                 regression[0][0]                 
__________________________________________________________________________________________________
classification_submodel (Model) (None, None, 1)      2381065     P3[0][0]                         
                                                                 P4[0][0]                         
                                                                 P5[0][0]                         
                                                                 P6[0][0]                         
                                                                 P7[0][0]                         
__________________________________________________________________________________________________
clipped_boxes (ClipBoxes)       (None, None, 4)      0           input_1[0][0]                    
                                                                 boxes[0][0]                      
__________________________________________________________________________________________________
classification (Concatenate)    (None, None, 1)      0           classification_submodel[1][0]    
                                                                 classification_submodel[2][0]    
                                                                 classification_submodel[3][0]    
                                                                 classification_submodel[4][0]    
                                                                 classification_submodel[5][0]    
__________________________________________________________________________________________________
filtered_detections (FilterDete [(None, 300, 4), (No 0           clipped_boxes[0][0]              
                                                                 classification[0][0]             
==================================================================================================
Total params: 36,382,957
Trainable params: 36,276,717
Non-trainable params: 106,240

Any help understanding how to fix the "Fetch" part would be greatly appreciated.

Edit:

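To dig further into this, I found a Python function that prints the operation names in a .pb file. Run on the FRCNN .pb file, it clearly shows the names of the output nodes (only the last few lines of the output are posted here):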

import/SecondStagePostprocessor/BatchMultiClassNonMaxSuppression/map/TensorArrayStack_4/TensorArrayGatherV3
import/SecondStagePostprocessor/ToFloat_1
import/add/y
import/add
import/detection_boxes
import/detection_scores
import/detection_classes
import/num_detections

If I do the same for the RetinaNet .pb file, it is much less obvious what the outputs are. Here are the last few lines of the function's output:

import/filtered_detections/map/while/NextIteration_4
import/filtered_detections/map/while/Exit_2
import/filtered_detections/map/while/Exit_3
import/filtered_detections/map/while/Exit_4
import/filtered_detections/map/TensorArrayStack/TensorArraySizeV3
import/filtered_detections/map/TensorArrayStack/range/start
import/filtered_detections/map/TensorArrayStack/range/delta
import/filtered_detections/map/TensorArrayStack/range
import/filtered_detections/map/TensorArrayStack/TensorArrayGatherV3
import/filtered_detections/map/TensorArrayStack_1/TensorArraySizeV3
import/filtered_detections/map/TensorArrayStack_1/range/start
import/filtered_detections/map/TensorArrayStack_1/range/delta
import/filtered_detections/map/TensorArrayStack_1/range
import/filtered_detections/map/TensorArrayStack_1/TensorArrayGatherV3
import/filtered_detections/map/TensorArrayStack_2/TensorArraySizeV3
import/filtered_detections/map/TensorArrayStack_2/range/start
import/filtered_detections/map/TensorArrayStack_2/range/delta
import/filtered_detections/map/TensorArrayStack_2/range
import/filtered_detections/map/TensorArrayStack_2/TensorArrayGatherV3
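The three `TensorArrayStack*/TensorArrayGatherV3` ops at the end of this list look like the stacked results of the `map` inside the `filtered_detections` layer, whose summary entry shows a list of outputs starting with `(None, 300, 4)`. Note that `tf.import_graph_def` prefixes every name with `import/` by default, so that prefix is most likely not part of the op name when the .pb is imported into TensorFlowSharp without a prefix (an assumption worth checking against your own import code). A minimal Python sketch that picks these ops out of a name list and turns them into fetchable tensor names; `retinanet_fetch_names` is a hypothetical helper, and ordering the results by the `_N` suffix assumes the stacks follow the layer's output order.

```python
import re

# Matches the stacked outputs of the map() inside the filtered_detections layer,
# e.g. "filtered_detections/map/TensorArrayStack_1/TensorArrayGatherV3".
_GATHER = re.compile(
    r"^filtered_detections/map/TensorArrayStack(?:_(\d+))?/TensorArrayGatherV3$"
)


def retinanet_fetch_names(op_names, prefix="import/"):
    """Return ':0' tensor names of the filtered_detections outputs, in stack order."""
    found = []
    for name in op_names:
        if prefix and name.startswith(prefix):
            name = name[len(prefix):]  # drop the prefix added by tf.import_graph_def
        m = _GATHER.match(name)
        if m:
            order = int(m.group(1)) if m.group(1) else 0  # no suffix means stack 0
            found.append((order, name + ":0"))
    return [tensor for _, tensor in sorted(found)]


ops = [
    "import/filtered_detections/map/TensorArrayStack/range",
    "import/filtered_detections/map/TensorArrayStack/TensorArrayGatherV3",
    "import/filtered_detections/map/TensorArrayStack_2/TensorArrayGatherV3",
    "import/filtered_detections/map/TensorArrayStack_1/TensorArrayGatherV3",
]
for t in retinanet_fetch_names(ops):
    print(t)
```

The input side is simpler: the summary's `InputLayer` is named `input_1`, so `input_1:0` is the likely tensor to feed.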

For reference, here is the Python function I used:

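import tensorflow as tf  # TF 1.x API (tf.gfile, tf.GraphDef); use tf.compat.v1 on TF 2.x

def printTensors(pb_file):

    # read the frozen .pb file into a GraphDef
    with tf.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # import the GraphDef; by default tf.import_graph_def prefixes every name with "import/"
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)

    # print the name of every operation in the graph
    for op in graph.get_operations():
        print(op.name)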
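Printing every op name gives a long list. A possibly quicker way to shortlist output candidates is to look for nodes whose results no other node consumes. A minimal sketch, assuming the `(name, inputs)` pairs are taken directly from the `GraphDef` that `printTensors` reads, e.g. `[(n.name, list(n.input)) for n in graph_def.node]` (reading the `GraphDef` directly also avoids the `import/` prefix). `candidate_outputs` is a hypothetical helper, and its result can also contain unused constants or other dangling nodes, so treat it as a shortlist rather than a definitive answer.

```python
def candidate_outputs(nodes):
    """Return names of nodes that no other node consumes.

    `nodes` is an iterable of (name, inputs) pairs, in the shape of
    GraphDef NodeDef entries (fields `name` and `input`).
    """
    nodes = list(nodes)
    consumed = set()
    for _, inputs in nodes:
        for inp in inputs:
            # GraphDef inputs look like "op", "op:1" or "^op" (control dependency)
            consumed.add(inp.lstrip("^").split(":")[0])
    return [name for name, _ in nodes if name not in consumed]


# Hypothetical toy graph standing in for a real graph_def.node list
toy = [
    ("input_1", []),
    ("conv1", ["input_1"]),
    ("filtered_detections/map/TensorArrayStack/TensorArrayGatherV3", ["conv1:0"]),
    ("filtered_detections/map/TensorArrayStack_1/TensorArrayGatherV3", ["conv1:0", "^input_1"]),
]
print(candidate_outputs(toy))
```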

Hopefully this helps.


1 answer

Stack Overflow user

Accepted answer

Posted on 2019-03-04 13:09:16

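I'm not quite sure what problem you're facing: you can get the output names from the TF Serving signature. In fact, the visualization IPython/Jupyter notebook also documents the output format.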

Querying the SavedModel:

  """  The given SavedModel SignatureDef contains the following output(s):
    outputs['filtered_detections/map/TensorArrayStack/TensorArrayGatherV3:0'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 300, 4)
        name: filtered_detections/map/TensorArrayStack/TensorArrayGatherV3:0
    outputs['filtered_detections/map/TensorArrayStack_1/TensorArrayGatherV3:0'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 300)
        name: filtered_detections/map/TensorArrayStack_1/TensorArrayGatherV3:0
    outputs['filtered_detections/map/TensorArrayStack_2/TensorArrayGatherV3:0'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 300)
        name: filtered_detections/map/TensorArrayStack_2/TensorArrayGatherV3:0
  Method name is: tensorflow/serving/predict
  ---
  From retina-net
  In general, inference of the network works as follows:
  boxes, scores, labels = model.predict_on_batch(inputs)
  Where `boxes` are shaped `(None, None, 4)` (for `(x1, y1, x2, y2)`), scores is shaped `(None, None)` (classification score) and labels is shaped `(None, None)` (label corresponding to the score). In all three outputs, the first dimension represents the shape and the second dimension indexes the list of detections.
"""
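Once the three tensors are fetched, they can be read as follows: the first dimension (`-1`/`None`) is the batch, and the second indexes up to 300 detections per image. Boxes are `(x1, y1, x2, y2)`, which differs from the `(ymin, xmin, ymax, xmax)` layout of `detection_boxes` in the FRCNN example, and the labels are `DT_INT32`, so in the C# code the third result would presumably be cast to `int[,]` rather than `float[,]`. The sketch below assumes, based on keras-retinanet's `filter_detections` as I understand it, that unused slots are padded with score and label `-1`; check this against your version. `keep_detections` is a hypothetical helper and the sample values are made up.

```python
def keep_detections(boxes, scores, labels, threshold=0.5):
    """Return (box, score, label) tuples for one image whose score passes threshold.

    boxes: 300 x 4 nested list of (x1, y1, x2, y2); scores, labels: length-300 lists.
    Padded slots are assumed to carry label -1 (and score -1), so they are skipped.
    """
    kept = []
    for box, score, label in zip(boxes, scores, labels):
        if label < 0 or score < threshold:
            continue  # padding or low-confidence detection
        kept.append((tuple(box), score, label))
    return kept


# Hypothetical outputs for a batch of one image, trimmed to 3 of the 300 slots
boxes = [[[10.0, 20.0, 110.0, 220.0], [5.0, 5.0, 50.0, 60.0], [-1.0, -1.0, -1.0, -1.0]]]
scores = [[0.92, 0.31, -1.0]]
labels = [[0, 0, -1]]
print(keep_detections(boxes[0], scores[0], labels[0]))  # index 0 = first image in the batch
```

Skipping with `continue` instead of stopping at the first low score avoids relying on the detections being sorted by score.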
Score 0
Original page content provided by Stack Overflow; translation support by the Tencent Cloud Xiaowei IT-domain engine.
Original link:

https://stackoverflow.com/questions/52939042
