首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为训练和验证指定单独的tf.summary file_writers

为训练和验证指定单独的tf.summary file_writers
EN

Stack Overflow用户
提问于 2020-11-28 04:12:19
回答 1查看 35关注 0票数 0

随着训练的进行,我想要可视化一些标量,如训练损失、val损失(以及其他)。我正在使用带有tf2.3的tf.Keras。然而,我无法为训练和验证指定单独的文件写入器,这导致tensorboard中的图形全部损坏。

我的拉力板看起来像这样:Tensorboard output我还会在每次训练之前清理日志。所以这不是来自之前的运行。问题是我只能设置一个默认编写器。那么,如何根据当前是否正在通过损失函数运行培训或评估来切换编写器?

伪代码:

代码语言:javascript
复制
def loss_fn():
  ..calculate loss..
  tf.summary.scalar('loss', loss)
def train():
  writer = tf.summary.create_file_writer(os.path.join(args.training_folder, 'logs')
  writer.set_as_default()
  model = create_model()
  model.compile(*arguments here*)
  model.fit(*arguments here*)
EN

回答 1

Stack Overflow用户

发布于 2020-11-28 04:32:25

修复很简单。这可以通过在keras回调的on_train_begin()和on_test_begin()方法中设置不同的编写器来解决

代码语言:javascript
复制
class TensorBoardFix(tf.keras.callbacks.TensorBoard):
    
    def __init__(self, training_folder, **kwargs):
        super().__init__()
        self.train_writer = tf.summary.create_file_writer(os.path.join(training_folder, 'logs', 'train'))
        self.val_writer = tf.summary.create_file_writer(os.path.join(training_folder, 'logs', 'val'))

    def on_train_begin(self, *args, **kwargs):
        super(TensorBoardFix, self).on_train_begin(*args, **kwargs)
        tf.summary.experimental.set_step(self._train_step)
        self.train_writer.set_as_default()

    def on_test_begin(self, *args, **kwargs):
        super(TensorBoardFix, self).on_test_begin(*args, **kwargs)
        tf.summary.experimental.set_step(self._val_step)
        self.val_writer.set_as_default()
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65043356

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档