首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >恢复培训tf.keras Tensorboard

恢复培训tf.keras Tensorboard
EN

Stack Overflow用户
提问于 2019-03-06 17:42:18
回答 3查看 3.2K关注 0票数 9

当我继续训练我的模型并在tensorboard上可视化进度时,我遇到了一些问题。

我的问题是,如何在不手动指定任何时期的情况下,从相同的步骤恢复训练?如果可能,只需加载保存的模型,它就可以以某种方式从保存的优化器读取global_step,并从那里继续训练。

我在下面提供了一些代码来重现类似的错误。

代码语言:javascript
复制
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.models import load_model

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=10, callbacks=[Tensorboard()])
model.save('./final_model.h5', include_optimizer=True)

del model

model = load_model('./final_model.h5')
model.fit(x_train, y_train, epochs=10, callbacks=[Tensorboard()])

您可以使用以下命令运行tensorboard

代码语言:javascript
复制
tensorboard --logdir ./logs
EN

回答 3

Stack Overflow用户

发布于 2019-03-06 22:02:42

您可以将函数model.fit()中的参数initial_epoch设置为您希望训练开始时的纪元编号。考虑到模型训练到索引epochs的时期(而不是epochs给出的迭代次数)。在你的例子中,如果你想再训练10个时期,它应该是:

代码语言:javascript
复制
model.fit(x_train, y_train, initial_epoch=9, epochs=19, callbacks=[Tensorboard()])

它将允许你以正确的方式在Tensorboard上可视化你的图。有关这些参数的更多信息可以在docs中找到。

票数 8
EN

Stack Overflow用户

发布于 2019-03-18 09:02:30

以下是示例代码,以备有人需要时使用。它实现了Abhinav Anand提出的想法:

代码语言:javascript
复制
mca = ModelCheckpoint(join(dir, 'model_{epoch:03d}.h5'),
                      monitor = 'loss',
                      save_best_only = False)
tb = TensorBoard(log_dir = join(dir, 'logs'),
                 write_graph = True,
                 write_images = True)
files = sorted(glob(join(fold_dir, 'model_???.h5')))
if files:
    model_file = files[-1]
    initial_epoch = int(model_file[-6:-3])
    print('Resuming using saved model %s.' % model_file)
    model = load_model(model_file)
else:
    model = nn.model()
    initial_epoch = 0
model.fit(x_train,
          y_train,
          epochs = 100,
          initial_epoch = initial_epoch,
          callbacks = [mca, tb])

nn.model()替换为您自己的定义模型的函数。

票数 4
EN

Stack Overflow用户

发布于 2019-03-07 14:08:14

这很简单。在训练模型时创建检查点,然后使用这些检查点从您离开的位置恢复训练。

代码语言:javascript
复制
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=10, callbacks=[Tensorboard()])
model.save('./final_model.h5', include_optimizer=True)

model = load_model('./final_model.h5')

callbacks = list()

tensorboard = Tensorboard()
callbacks.append(tensorboard)

file_path = "model-{epoch:02d}-{loss:.4f}.hdf5"

# now here you can create checkpoints and save according to your need
# here period is the no of epochs after which to save the model every time during training
# another option is save_weights_only, for your case it should be false
checkpoints = ModelCheckpoint(file_path, monitor='loss', verbose=1, period=1, save_weights_only=False)
callbacks.append(checkpoints)

model.fit(x_train, y_train, epochs=10, callbacks=callbacks)

在此之后,只需加载检查点,您就可以再次从该检查点恢复训练

代码语言:javascript
复制
model = load_model(checkpoint_of_choice)
model.fit(x_train, y_train, epochs=10, callbacks=callbacks)

你就完了。

如果你对此有更多的问题,请告诉我。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55019885

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档