首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在使用预训练的模型和配置文件时,如何停止基于损失的训练?

在使用预训练的模型和配置文件时,如何停止基于损失的训练?
EN

Stack Overflow用户
提问于 2020-09-03 21:30:36
回答 1查看 225关注 0票数 1

我正在使用一个更快的RCNN模型来训练一个对象检测器,使用的是Pipeline配置文件。我知道可以通过直接取消(ctrl+c)来停止培训。我希望训练根据损失值自动停止。如何做到这一点?我知道keras回调可以在监控时期时使用。在使用配置文件和预先训练的模型(用于监控步骤)时,是否有这样的选项。谢谢。

EN

回答 1

Stack Overflow用户

发布于 2020-09-04 02:28:28

这可能只是一个技巧,但我找到了我的问题的解决方案。对象检测器需要安装tf_slim包。在tf_slim包中,有一个名为learning.py的模块。它的完整路径可能如下所示:/usr/local/lib/python3.6/site-packages/tf_slim/learning.py在这里,在learning.py中,从第764行开始,代码如下:

代码语言:javascript
复制
try:
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
# OutOfRangeError is thrown when epoch limit per
# tf.compat.v1.train.limit_epochs is reached.
logging.info('Caught OutOfRangeError. Stopping Training. %s', e)

我编写了一个小的if语句来检查total_loss的最后五个值的最大值,如果低于某个阈值(在本例中为3),则将其设为should_stop True。如下所示:

代码语言:javascript
复制
try:
  total_loss_list = []
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    total_loss_list.append(total_loss)
    if len(total_loss_list) > 5:
      if max(total_loss_list[-5:]) < 3:
        should_stop = True
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
  # OutOfRangeError is thrown when epoch limit per
  # tf.compat.v1.train.limit_epochs is reached.
  logging.info('Caught OutOfRangeError. Stopping Training. %s', e)

如果损失值在五个步骤中持续低于3,则训练将停止。这样做的缺点是,必须更改tf_slim的包分发版本。每次处理新的目标检测问题时,这个阈值损失值都会发生变化。更好的方法是使用配置文件,您可以在其中提供阈值损失值。但我现在就到此为止吧。如果其他人有更好的解决方案,请分享。我希望这对某些人有帮助。谢谢!

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63724730

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档