首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在自定义数据集上训练Faster-RCNN模型时加载检查点

在自定义数据集上训练Faster-RCNN模型时加载检查点
EN

Stack Overflow用户
提问于 2021-10-24 03:54:20
回答 1查看 138关注 0票数 6

我正在尝试使用Faster-RCNN架构加载检查点和填充模型权重(准确地说,是来自hereFaster R-CNN ResNet50 V1 640x640。我正在尝试加载这个网络的权重,类似于它在example notebook for RetinaNet中的操作方式,其中它们执行以下操作:

代码语言:javascript
复制
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
)

fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor
)

ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

我正在尝试获得一个类似的检查点加载机制,用于我想使用的更快的RCNN网络,但像_base_tower_layers_for_heads_box_prediction_head这样的属性只存在于示例中使用的体系结构中,而不存在于其他任何情况下。

对于我的特定用例,我也找不到关于使用Checkpoint填充模型的哪些部分的文档。将非常感谢任何关于如何处理这一点的帮助!

EN

回答 1

Stack Overflow用户

发布于 2021-11-01 08:00:12

正如你所说的,你遇到的主要问题是,你没有层张量,你想要在它上面进行迁移学习。这是动物园最初实施的更快的R-CNN ResNet50 V1 640x640副本的一部分。他们没有命名这些层,或者他们确实命名了这些层,但并没有发布这些名称(或代码)。为了解决这个问题,你需要找出你想要保留的层和想要重新学习的层。您可以使用(ref)打印出网络中的所有图层:

代码语言:javascript
复制
[n.name for n in tf.get_default_graph().as_graph_def().node]

可以手动添加图层名称,但tf会为每个节点保留默认名称。这个列表可能很长,很累人,但是你需要找到开始迁移学习的张量。因此,您需要遵循列表,并尝试了解您想要冻结哪些层,以及您想要继续学习过程。冻结层(ref):

代码语言:javascript
复制
if layer.name == 'layer_name':
    layer.trainable = False
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69693757

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档