随着训练的进行,我想要可视化一些标量,如训练损失、val损失(以及其他)。我正在使用带有tf2.3的tf.Keras。然而,我无法为训练和验证指定单独的文件写入器,这导致tensorboard中的图形全部损坏。
我的拉力板看起来像这样:Tensorboard output我还会在每次训练之前清理日志。所以这不是来自之前的运行。问题是我只能设置一个默认编写器。那么,如何根据当前是否正在通过损失函数运行培训或评估来切换编写器?
伪代码:
def loss_fn():
..calculate loss..
tf.summary.scalar('loss', loss)
def train():
writer = tf.summary.create_file_writer(os.path.join(args.training_folder, 'logs')
writer.set_as_default()
model = create_model()
model.compile(*arguments here*)
model.fit(*arguments here*)发布于 2020-11-28 04:32:25
修复很简单。这可以通过在keras回调的on_train_begin()和on_test_begin()方法中设置不同的编写器来解决
class TensorBoardFix(tf.keras.callbacks.TensorBoard):
def __init__(self, training_folder, **kwargs):
super().__init__()
self.train_writer = tf.summary.create_file_writer(os.path.join(training_folder, 'logs', 'train'))
self.val_writer = tf.summary.create_file_writer(os.path.join(training_folder, 'logs', 'val'))
def on_train_begin(self, *args, **kwargs):
super(TensorBoardFix, self).on_train_begin(*args, **kwargs)
tf.summary.experimental.set_step(self._train_step)
self.train_writer.set_as_default()
def on_test_begin(self, *args, **kwargs):
super(TensorBoardFix, self).on_test_begin(*args, **kwargs)
tf.summary.experimental.set_step(self._val_step)
self.val_writer.set_as_default()https://stackoverflow.com/questions/65043356
复制相似问题