首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何恢复LSTM层

如何恢复LSTM层
EN

Stack Overflow用户
提问于 2017-07-17 22:12:49
回答 1查看 622关注 0票数 7

如果我能在挽救和修复LSTM方面得到一些帮助,我会非常感激的。

我有这个LSTM层-

代码语言:javascript
复制
# LSTM cell
cell = tf.contrib.rnn.LSTMCell(n_hidden)
output, current_state = tf.nn.dynamic_rnn(cell, word_vectors, dtype=tf.float32)

outputs = tf.transpose(output, [1, 0, 2])
last = tf.gather(outputs, int(outputs.get_shape()[0]) - 1)

# Saver function
saver = tf.train.Saver()
saver.save(sess, 'test-model')

保护程序保存模型,并允许我保存和恢复LSTM的权重和偏差。但是,我需要恢复这个LSTM层,并为它提供一组新的输入。

为了恢复整个模型,我正在做:

代码语言:javascript
复制
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('test-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
  1. 我能用预先训练过的重量和偏差初始化一个LSTM单元吗?
  2. 如果没有,如何恢复这个LSTM层?

非常感谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-07-17 23:42:31

您已经在加载模型,因此模型的权重。您所需要做的就是使用get_tensor_by_name从图中获取任何张量,并使用它进行推理。

示例:

代码语言:javascript
复制
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('test-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

   # Get the tensors by their variable name
   word_vec = = detection_graph.get_tensor_by_name('word_vec:0')
   output_tensor = detection_graph.get_tensor_by_name('outputs:0')

   sess.run(output_tensor, feed_dict={word_vec: ...}) 

在上面的示例中,word_vecoutputs是在创建图形期间分配给张量的名称。确保您指定了名称,以便可以按其名称调用它们。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45154459

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档