首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >无法正确恢复注意力模型

无法正确恢复注意力模型
EN

Data Science用户
提问于 2020-04-12 19:52:42
回答 1查看 182关注 0票数 1

我指的是关于使用tensorflow构建注意模型的文章。

我正在尝试使用google在我的数据集上训练一个类似的模型。由于colab和我的大型数据集的会话限制,我需要保存模型状态并恢复它以恢复培训。

但是,在保存参数时,我无法恢复模型。我保存了输入和目标标记器、模型检查点,甚至是输入和输出张量。然而,每次我使用checkpoint.restore并继续训练该模型时,它都会以很高的损失(等于随机权重)恢复训练。

在使用转换函数保存某些测试数据之前,我总是测试我的模型,它生成一行摘要。但是,当我恢复模型并在转换函数上运行一些样本数据时,我只得到一个标记作为输出(就好像它是一个新初始化的模型)。

这是我的密码

代码语言:javascript
复制
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

manager = tf.train.CheckpointManager(checkpoint, 'checkpoint_dir', max_to_keep=1)

训练步骤是

代码语言:javascript
复制
EPOCHS = 50

for epoch in range(EPOCHS):
  start = time.time()

  enc_hidden = encoder.initialize_hidden_state()
  total_loss = 0

  for (batch, (inp, targ)) in tqdm(enumerate(dataset.take(steps_per_epoch))):
    batch_loss = train_step(inp, targ, enc_hidden)
    total_loss += batch_loss

    if batch % 100 == 0:
      print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                   batch,
                                                   batch_loss.numpy()))
  # saving (checkpoint) the model every 3 epochs
  if (epoch + 1) % 3 == 0:
    manager.save()

  print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                      total_loss / steps_per_epoch))
  print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

我靠做来恢复

代码语言:javascript
复制
checkpoint.restore('ckpt-ckptnumber.index')

我使用泡菜保存令牌器(输入和输出)。

代码语言:javascript
复制
with open('inp_tokenizer.pickle', 'wb') as handle:
      pickle.dump(inp_lang_tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)

我使用numpy.save()保存张量

代码语言:javascript
复制
np.save('X.npy', input_tensor)
EN

回答 1

Data Science用户

回答已采纳

发布于 2020-04-13 07:14:47

您必须使用ckpt.restore(“./tf_ckpt/ckpt-10”)之类的全部东西。请查一下https://www.tensorflow.org/guide/checkpoint

票数 0
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/72206

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档