Plugins for TensorRT

Stack Overflow user
Asked on 2019-08-08 22:27:06
1 answer, 638 views, 0 followers, 2 votes

My TensorFlow model uses the TensorFlow ops ResizeArea, Select, Fill and Equal. When I convert the model to UFF, I get the following warnings:

```text
Warning: No conversion function registered for layer: ResizeArea yet.
Converting upsample_heatmat as custom op: ResizeArea
Warning: No conversion function registered for layer: Select yet.
Converting Select as custom op: Select
Warning: No conversion function registered for layer: Fill yet.
Converting zeros_like as custom op: Fill
Warning: No conversion function registered for layer: Equal yet.
Converting Equal as custom op: Equal
```
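Each "Converting X as custom op: Y" line pairs a graph node name (X) with its op type (Y), and the node names (for example `upsample_heatmat` and `zeros_like`) are what the namespace-to-plugin map has to be keyed on. A small sketch that pulls those pairs out of a converter log, assuming the message format is exactly the one shown in this log (other UFF converter versions may word it differently); `custom_ops_from_log` is a hypothetical helper name:

```python
import re

# Matches lines such as "Converting upsample_heatmat as custom op: ResizeArea".
# Assumption: the wording is exactly as in the log above.
CUSTOM_OP_RE = re.compile(r"^Converting (\S+) as custom op: (\S+)\s*$")


def custom_ops_from_log(log_text):
    """Return {node_name: op_type} for every node the UFF converter emitted as a custom op."""
    ops = {}
    for line in log_text.splitlines():
        match = CUSTOM_OP_RE.match(line.strip())
        if match:
            node_name, op_type = match.groups()
            ops[node_name] = op_type
    return ops


if __name__ == "__main__":
    log = """\
Warning: No conversion function registered for layer: ResizeArea yet.
Converting upsample_heatmat as custom op: ResizeArea
Warning: No conversion function registered for layer: Select yet.
Converting Select as custom op: Select
Warning: No conversion function registered for layer: Fill yet.
Converting zeros_like as custom op: Fill
Warning: No conversion function registered for layer: Equal yet.
Converting Equal as custom op: Equal
"""
    # Print each node name next to the op type the converter could not translate.
    for node, op in custom_ops_from_log(log).items():
        print(f"{node:20s} -> {op}")
```

The "Warning:" lines are skipped on purpose; only the "Converting ..." lines carry the node name.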

So I created plugins for ResizeArea, Select, Fill and Equal.
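When writing such plugins it helps to have a reference result to compare their output against. The sketch below gives NumPy references for the four ops. Assumptions, all mine: arrays are laid out as (H, W, C) (the real plugins may well use CHW, as TensorRT commonly does), the upscale factor is an integer, and for an integer upscale with TensorFlow's default (non-align-corners) behaviour each output pixel's area lies inside a single input pixel, so area resizing reduces to pixel replication. The function names are hypothetical.

```python
import numpy as np


def resize_area_upscale(x, upscale):
    """Reference for ResizeArea on an (H, W, C) array with an integer upscale factor.

    Assumption: with an integer upscale, area interpolation reduces to repeating
    every input pixel upscale x upscale times.
    """
    return np.repeat(np.repeat(x, upscale, axis=0), upscale, axis=1)


def fill_like(x, value=0.0):
    """Reference for Fill / zeros_like: an array shaped like x, filled with value."""
    return np.full_like(x, value)


def equal(a, b):
    """Reference for Equal: elementwise comparison giving a boolean mask."""
    return np.equal(a, b)


def select(cond, a, b):
    """Reference for Select: take a where cond is true, b elsewhere."""
    return np.where(cond, a, b)


if __name__ == "__main__":
    # 60 x 80 x 3 input, matching in_height=60, in_width=80, in_channel=3 in the question.
    heat = np.random.rand(60, 80, 3).astype(np.float32)
    up = resize_area_upscale(heat, 4)
    print(up.shape)  # (240, 320, 3)

    # Example chain: keep values equal to the per-channel maximum, zero elsewhere.
    mask = equal(up, up.max(axis=(0, 1), keepdims=True))
    kept = select(mask, up, fill_like(up))
    print(int(mask.sum()), float(kept.max()))
```

Dumping a plugin's output buffer to NumPy and comparing it with these references (for example with `np.allclose`) is a quick way to separate "plugin not loaded" from "plugin computes the wrong thing".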

Then I mapped the plugins to the TensorFlow ops:

```python
import os

import graphsurgeon as gs
import uff


class ModelData(object):
    INPUT_NAME = "image"
    EQUAL_NAME = "Equal"
    SELECT_NAME = "Select"
    PMAT_NAME = "upsample_pafmat"
    ZERO_LIKE = "zeros_like"
    HMAT_NAME = "upsample_heatmat"
    OUTPUT_NAME = "Openpose/output"


def prepare_namespace_plugin_map():
    # ResizeArea, Select, Fill (zeros_like) and Equal have no UFF conversion
    # function, so each is replaced by a plugin node that tells the UffParser
    # which TensorRT plugin to run, and with which arguments.
    #
    # The keyword arguments (in_width, in_height, in_channel, upscale, value)
    # become attributes of the plugin node; they are parsed by the plugin
    # creator's createPlugin to build the plugin with the right parameters.
    trt_resizearea = gs.create_plugin_node(name="trt_resizearea", op="ResizeAreaPlugin", in_width=80.0, in_height=60.0, in_channel=3.0, upscale=4.0)
    trt_fill = gs.create_plugin_node(name="trt_fill", op="FillPlugin", in_width=320.0, in_height=240.0, in_channel=3.0, value=0.0)  # fill with 0
    trt_equal = gs.create_plugin_node(name="trt_equal", op="EqualPlugin", in_width=320.0, in_height=240.0, in_channel=3.0)
    trt_select = gs.create_plugin_node(name="trt_select", op="SelectPlugin", in_width=320.0, in_height=240.0, value=0.0)
    # Map each TensorFlow namespace (node name) to the plugin node that replaces it.
    namespace_plugin_map = {
        ModelData.SELECT_NAME: trt_select,
        ModelData.EQUAL_NAME: trt_equal,
        ModelData.PMAT_NAME: trt_resizearea,
        ModelData.HMAT_NAME: trt_resizearea,
        ModelData.ZERO_LIKE: trt_fill,
    }
    return namespace_plugin_map


def model_to_uff(model_path):
    # Transform the graph with graphsurgeon to map unsupported TensorFlow
    # operations to the corresponding TensorRT custom-layer plugins.
    dynamic_graph = gs.DynamicGraph(model_path)
    dynamic_graph.collapse_namespaces(prepare_namespace_plugin_map())
    # Save the resulting graph to a UFF file.
    output_uff_path = model_path_to_uff_path(model_path)
    uff.from_tensorflow(
        dynamic_graph.as_graph_def(),
        [ModelData.OUTPUT_NAME],
        output_filename=output_uff_path,
        text=True,
    )
    return output_uff_path


def model_path_to_uff_path(model_path):
    # Replace the frozen graph's extension with ".uff".
    uff_path = os.path.splitext(model_path)[0] + ".uff"
    return uff_path
```
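Because every plugin node carries hard-coded sizes, those sizes have to agree with each other: the ResizeArea plugin's 80×60 input upscaled by 4 gives 320×240, which is exactly the size the Fill, Equal and Select nodes are configured with. Assuming those three ops consume the upsampled maps (the matching 320×240 suggests so, but the graph itself is not shown), a quick arithmetic check could look like this; the attribute values are copied from the `create_plugin_node` calls above and the helper names are hypothetical:

```python
# Attribute values copied from the create_plugin_node calls in the question.
PLUGIN_ATTRS = {
    "trt_resizearea": {"in_width": 80.0, "in_height": 60.0, "in_channel": 3.0, "upscale": 4.0},
    "trt_fill": {"in_width": 320.0, "in_height": 240.0, "in_channel": 3.0, "value": 0.0},
    "trt_equal": {"in_width": 320.0, "in_height": 240.0, "in_channel": 3.0},
    "trt_select": {"in_width": 320.0, "in_height": 240.0, "value": 0.0},
}


def resize_output_hw(attrs):
    """Output (height, width) of the resize plugin: input size times the upscale factor."""
    return attrs["in_height"] * attrs["upscale"], attrs["in_width"] * attrs["upscale"]


def mismatched_consumers(plugin_attrs, producer, consumers):
    """Return the consumers whose (in_height, in_width) differs from the producer's output."""
    expected = resize_output_hw(plugin_attrs[producer])
    return [
        name
        for name in consumers
        if (plugin_attrs[name]["in_height"], plugin_attrs[name]["in_width"]) != expected
    ]


if __name__ == "__main__":
    consumers = ["trt_fill", "trt_equal", "trt_select"]
    print(resize_output_hw(PLUGIN_ATTRS["trt_resizearea"]))  # (240.0, 320.0)
    print(mismatched_consumers(PLUGIN_ATTRS, "trt_resizearea", consumers))  # []
```

A mismatch here would not cause the UFF warnings, but it would make a plugin read or write the wrong amount of memory once the engine runs.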

Why do I still get the following warnings?

```text
Warning: No conversion function registered for layer: ResizeAreaPlugin yet.
Converting trt_resizearea as custom op: ResizeAreaPlugin
W0808 17:44:51.442725 139793630279424 deprecation_wrapper.py:119] From /home/coie/Data/coie/Softwares/venv/lib/python3.5/site-packages/uff/converters/tensorflow/converter.py:179: The name tf.AttrValue is deprecated. Please use tf.compat.v1.AttrValue instead.

Warning: No conversion function registered for layer: SelectPlugin yet.
Converting trt_select as custom op: SelectPlugin
Warning: No conversion function registered for layer: FillPlugin yet.
Converting trt_fill as custom op: FillPlugin
Warning: No conversion function registered for layer: EqualPlugin yet.
Converting trt_equal as custom op: EqualPlugin
```

What could be wrong?


1 Answer

Stack Overflow user

Answered on 2019-08-16 10:19:21

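My solution was correct. The warnings are produced even when the plugins have been mapped to the TensorFlow ops successfully: the UFF converter has no built-in conversion function for op types such as ResizeAreaPlugin, so it writes each plugin node into the UFF file as a custom op, which is what the mapping is meant to achieve. To check whether the plugins are actually loaded in the running TensorRT engine, add print statements inside the plugin API implementation and look for their output.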
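Printing from inside the C++ plugin methods is the direct runtime check. A complementary check from Python, before parsing the UFF file, is to ask TensorRT's global plugin registry which creators are registered. This sketch assumes the plugins are compiled into a shared library whose creators register themselves when it is loaded (for example via `REGISTER_TENSORRT_PLUGIN`), and that each creator's name matches the `op` string passed to `create_plugin_node`, since that is the name the UFF parser looks up. The library path and helper names are hypothetical:

```python
import ctypes
import os

try:
    import tensorrt as trt  # optional: only the live registry check needs it
except ImportError:
    trt = None

# Op names passed to gs.create_plugin_node in the question.
REQUIRED_PLUGINS = {"ResizeAreaPlugin", "FillPlugin", "EqualPlugin", "SelectPlugin"}

# Hypothetical path to the shared library that builds and registers the plugin creators.
PLUGIN_LIBRARY = "./libcustom_plugins.so"


def missing_plugins(registered_names, required=REQUIRED_PLUGINS):
    """Return the required plugin names absent from registered_names, sorted."""
    return sorted(set(required) - set(registered_names))


def registered_plugin_names(plugin_library):
    """Load plugin_library, then list creator names in TensorRT's global plugin registry.

    Loading the library with ctypes runs its static initializers, which is where
    registration macros in the C++ sources add each creator to the registry.
    """
    if trt is None:
        raise RuntimeError("TensorRT Python bindings are not installed")
    ctypes.CDLL(plugin_library)
    return [creator.name for creator in trt.get_plugin_registry().plugin_creator_list]


if __name__ == "__main__":
    if trt is not None and os.path.exists(PLUGIN_LIBRARY):
        print("missing:", missing_plugins(registered_plugin_names(PLUGIN_LIBRARY)))
    else:
        print("TensorRT or the plugin library not found; skipping the live check")
```

An empty "missing" list only shows the creators are registered; it does not prove the engine calls them, which is why the print-inside-the-plugin check above is still worth doing.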

Votes: 0
Original page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/57414905
