我指的是关于使用tensorflow构建注意模型的这文章。
我正在尝试使用google在我的数据集上训练一个类似的模型。由于colab和我的大型数据集的会话限制,我需要保存模型状态并恢复它以恢复培训。
但是,在保存参数时,我无法恢复模型。我保存了输入和目标标记器、模型检查点,甚至是输入和输出张量。然而,每次我使用checkpoint.restore并继续训练该模型时,它都会以很高的损失(等于随机权重)恢复训练。
在使用转换函数保存某些测试数据之前,我总是测试我的模型,它生成一行摘要。但是,当我恢复模型并在转换函数上运行一些样本数据时,我只得到一个标记作为输出(就好像它是一个新初始化的模型)。
这是我的密码
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
encoder=encoder,
decoder=decoder)
manager = tf.train.CheckpointManager(checkpoint, 'checkpoint_dir', max_to_keep=1)训练步骤是
EPOCHS = 50
for epoch in range(EPOCHS):
start = time.time()
enc_hidden = encoder.initialize_hidden_state()
total_loss = 0
for (batch, (inp, targ)) in tqdm(enumerate(dataset.take(steps_per_epoch))):
batch_loss = train_step(inp, targ, enc_hidden)
total_loss += batch_loss
if batch % 100 == 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
batch,
batch_loss.numpy()))
# saving (checkpoint) the model every 3 epochs
if (epoch + 1) % 3 == 0:
manager.save()
print('Epoch {} Loss {:.4f}'.format(epoch + 1,
total_loss / steps_per_epoch))
print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))我靠做来恢复
checkpoint.restore('ckpt-ckptnumber.index')我使用泡菜保存令牌器(输入和输出)。
with open('inp_tokenizer.pickle', 'wb') as handle:
pickle.dump(inp_lang_tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)我使用numpy.save()保存张量
np.save('X.npy', input_tensor)发布于 2020-04-13 07:14:47
您必须使用ckpt.restore(“./tf_ckpt/ckpt-10”)之类的全部东西。请查一下https://www.tensorflow.org/guide/checkpoint。
https://datascience.stackexchange.com/questions/72206
复制相似问题