首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow rnn模型路径

tensorflow rnn模型路径
EN

Stack Overflow用户
提问于 2015-11-29 15:35:40
回答 1查看 1.4K关注 0票数 2

我已经使用Tensorflow训练了语言模型,如本tutorial所示

为了进行训练,我使用了以下命令。

代码语言:javascript
复制
 bazel-bin/tensorflow/models/rnn/ptb/ptb_word_lm   --data_path=./simple-examples/data/  --model small

训练是成功的,最后的o/p如下。

代码语言:javascript
复制
Epoch: 13 Train Perplexity: 37.196
Epoch: 13 Valid Perplexity: 124.502
Test Perplexity: 118.624

但我仍然对训练模型存储在哪里以及如何使用它感到困惑。

EN

回答 1

Stack Overflow用户

发布于 2015-11-29 16:56:03

演示代码可能不包括保存模型的功能;您可能希望显式地使用tf.train.Saver将变量保存到检查点以及从检查点恢复变量。

参见docexamples

根据医生的说法,这很简单。在下面的例子中,我保存了模型中的所有变量。相反,您可以通过遵循examples来选择要保存的变量。

代码语言:javascript
复制
# ... 
tf.initialize_all_variables().run()
####################################################
# Add ops to save and restore all the variables.
####################################################
saver = tf.train.Saver()

for i in range(config.max_max_epoch):
  lr_decay = config.lr_decay ** max(i - config.max_epoch, 0.0)
  m.assign_lr(session, config.learning_rate * lr_decay)

  print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
  train_perplexity = run_epoch(session, m, train_data, m.train_op,
                               verbose=True)
  print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
  valid_perplexity = run_epoch(session, mvalid, valid_data, tf.no_op())
  print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))

  ####################################################
  # Save the variables to disk.
  ####################################################
  save_path = saver.save(session, "/tmp/model.epoch.%03d.ckpt" % (i + 1))
  print("Model saved in file: %s" % save_path)
  # ....

在我的示例中,每个检查点文件的磁盘大小为18.61M (--model small)。

关于如何使用模型,只需按照doc从保存的文件中恢复检查点即可。然后,如何使用它取决于你的意愿。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/33980496

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档