首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >无法从检查点恢复: bidirectional/backward_lstm/bias

无法从检查点恢复: bidirectional/backward_lstm/bias
EN

Stack Overflow用户
提问于 2019-06-26 23:14:58
回答 1查看 301关注 0票数 1

我正在尝试用tensor2tensor创建一个简单的基于LSTM的RNN。

到目前为止,训练似乎是有效的,但我无法恢复模型。尝试这样做将抛出一个NotFoundError,指出来自LSTM的偏置节点:

代码语言:javascript
复制
NotFoundError: .. 

Key bidirectional/backward_lstm/bias not found in checkpoint

我不知道为什么会这样。

这实际上是另一个问题的变通方法,我可以使用来自tensor2tensor (https://github.com/tensorflow/tensor2tensor/issues/1616)的LSTM来解决类似的问题。

环境

代码语言:javascript
复制
$ pip freeze | grep tensor
mesh-tensorflow==0.0.5
tensor2tensor==1.12.0
tensorboard==1.12.0
tensorflow-datasets==1.0.2
tensorflow-estimator==1.13.0
tensorflow-gpu==1.12.0
tensorflow-metadata==0.9.0
tensorflow-probability==0.5.0

模型体

代码语言:javascript
复制
def body(self, features):

    inputs = features['inputs'][:,:,0,:]

    hparams = self._hparams
    problem = hparams.problem
    encoders = problem.feature_info

    max_input_length = 350
    max_output_length = 350 

    encoder = Bidirectional(LSTM(128, return_sequences=True, unroll=False), merge_mode='concat')(inputs)
    encoder_last = encoder[:, -1, :]

    decoder = LSTM(256, return_sequences=True, unroll=False)(inputs, initial_state=[encoder_last, encoder_last])

    attention = dot([decoder, encoder], axes=[2, 2])
    attention = Activation('softmax', name='attention')(attention)

    context = dot([attention, encoder], axes=[2, 1])
    concat = concatenate([context, decoder])

    return tf.expand_dims(concat, 2)

完全错误

代码语言:javascript
复制
NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key while/lstm_keras/parallel_0_4/lstm_keras/lstm_keras/body/bidirectional/backward_lstm/bias not found in checkpoint
     [[node save/RestoreV2 (defined at /home/sfalk/tmp/pycharm_project_265/asr/model/persistence.py:282)  = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

任何问题可能是什么以及如何解决这个问题?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-08-08 00:38:32

这似乎与https://github.com/tensorflow/tensor2tensor/issues/1486有关。在使用tensor2tensor从检查点恢复期间,"while“似乎位于键名的前面。这似乎是一个未解决的错误,您的意见将在github上得到感谢。

如果可以的话,我会对此发表评论,但我的名声太低了。干杯。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56776076

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档