
Applying attention to an encoder-decoder (Seq2Seq) inference model

Stack Overflow user
Asked on 2020-07-18 11:57:50

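Hello, Stack Overflow community!

I am trying to build an inference model with attention for a seq2seq (encoder-decoder) model. Here is the definition of the inference model.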

model = compile_model(tf.keras.models.load_model(constant.MODEL_PATH, compile=False))

encoder_input = model.input[0]
encoder_output, encoder_h, encoder_c = model.layers[1].output
encoder_state = [encoder_h, encoder_c]
encoder_model = tf.keras.Model(encoder_input, encoder_state)

decoder_input = model.input[1]
decoder = model.layers[3]
decoder_new_h = tf.keras.Input(shape=(n_units,), name='input_3')
decoder_new_c = tf.keras.Input(shape=(n_units,), name='input_4')
decoder_input_initial_state = [decoder_new_h, decoder_new_c]

decoder_output, decoder_h, decoder_c = decoder(decoder_input, initial_state=decoder_input_initial_state)
decoder_output_state = [decoder_h, decoder_c]

# These lines cause an error
context = model.layers[4]([encoder_output, decoder_output])
decoder_combined_context = model.layers[5]([context, decoder_output])
output = model.layers[6](decoder_combined_context)
output = model.layers[7](output)
# end

decoder_model = tf.keras.Model([decoder_input] + decoder_input_initial_state, [output] + decoder_output_state)
return encoder_model, decoder_model

When I run this code, I get the following error.

ValueError: Graph disconnected: cannot obtain value for tensor Tensor("input_5:0", shape=(None, None, 20), dtype=float32) at layer "lstm_4". The following previous layers were accessed without issue: ['lstm_5']

If I leave out the attention block, the model builds without any error.

model = compile_model(tf.keras.models.load_model(constant.MODEL_PATH, compile=False))

encoder_input = model.input[0]
encoder_output, encoder_h, encoder_c = model.layers[1].output
encoder_state = [encoder_h, encoder_c]
encoder_model = tf.keras.Model(encoder_input, encoder_state)

decoder_input = model.input[1]
decoder = model.layers[3]
decoder_new_h = tf.keras.Input(shape=(n_units,), name='input_3')
decoder_new_c = tf.keras.Input(shape=(n_units,), name='input_4')
decoder_input_initial_state = [decoder_new_h, decoder_new_c]

decoder_output, decoder_h, decoder_c = decoder(decoder_input, initial_state=decoder_input_initial_state)
decoder_output_state = [decoder_h, decoder_c]

# These lines cause an error
# context = model.layers[4]([encoder_output, decoder_output])
# decoder_combined_context = model.layers[5]([context, decoder_output])
# output = model.layers[6](decoder_combined_context)
# output = model.layers[7](output)
# end

decoder_model = tf.keras.Model([decoder_input] + decoder_input_initial_state, [decoder_output] + decoder_output_state)
return encoder_model, decoder_model

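1 Answer

Stack Overflow user

Accepted answer

Answered on 2020-07-18 12:32:36

I think the encoder output also needs to be exposed as an output of the encoder model and then fed to the decoder model as an input. In the question's code the attention layer receives `encoder_output`, which depends on `model.input[0]`, and that tensor is not among the decoder model's inputs, so the graph is disconnected. Maybe these changes will help: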

model = compile_model(tf.keras.models.load_model(constant.MODEL_PATH, compile=False))
encoder_input = model.input[0]
encoder_output, encoder_h, encoder_c = model.layers[1].output
encoder_state = [encoder_h, encoder_c]
encoder_model = tf.keras.Model(inputs=[encoder_input], outputs=encoder_state + [encoder_output])

decoder_input = model.input[1]
decoder_input2 = tf.keras.Input(shape=x)  # x is the encoder output shape without the batch dimension, e.g. (None, n_units)
decoder = model.layers[3]
decoder_new_h = tf.keras.Input(shape=(n_units,), name='input_3')
decoder_new_c = tf.keras.Input(shape=(n_units,), name='input_4')
decoder_input_initial_state = [decoder_new_h, decoder_new_c]

decoder_output, decoder_h, decoder_c = decoder(decoder_input, initial_state=decoder_input_initial_state)
decoder_output_state = [decoder_h, decoder_c]

context = model.layers[4]([decoder_input2, decoder_output])
decoder_combined_context = model.layers[5]([context, decoder_output])
output = model.layers[6](decoder_combined_context)
output = model.layers[7](output)

decoder_model = tf.keras.Model([decoder_input, decoder_input2] + decoder_input_initial_state, [output] + decoder_output_state)
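
To see why the decoder model cannot do without the encoder output, it may help to look at what an attention block computes. The question does not show which layer sits at `model.layers[4]`, so the following is only a minimal pure-Python sketch that assumes dot-product scoring; `dot_attention` and the toy vectors are hypothetical and not part of the original model.

```python
import math

def dot_attention(decoder_steps, encoder_steps):
    """Return one context vector per decoder step (dot-product scoring, assumed)."""
    dim = len(encoder_steps[0])
    contexts = []
    for query in decoder_steps:
        # Score the decoder output against every encoder time step.
        scores = [sum(q * k for q, k in zip(query, enc)) for enc in encoder_steps]
        # Softmax over encoder time steps, shifted by the max for numerical stability.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # The context is the attention-weighted sum of the encoder outputs.
        contexts.append(
            [sum(w * enc[d] for w, enc in zip(weights, encoder_steps)) for d in range(dim)]
        )
    return contexts

# Toy data: three encoder time steps and one decoder step, two features each.
encoder_steps = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
decoder_steps = [[0.0, 0.0]]
contexts = dot_attention(decoder_steps, encoder_steps)
```

Every context vector is built from the encoder outputs, so at inference time those values have to reach the decoder model through one of its own inputs rather than through the training graph.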
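
Once the two inference models are separated this way, decoding usually runs the encoder once and then calls the decoder step by step, passing the encoder output in again at every step. Below is a rough, framework-agnostic sketch of such a greedy loop. The names `greedy_decode`, `encode` and `decode_step` are hypothetical, and the toy stand-ins only exist to make the sketch runnable without a trained model.

```python
def greedy_decode(encode, decode_step, source, start_token, end_token, max_len=50):
    """Greedy decoding loop.

    encode(source) -> (state, encoder_output)
    decode_step(token, encoder_output, state) -> (scores, new_state)
    """
    # Run the encoder once; its output is reused at every decoder step.
    state, encoder_output = encode(source)
    token = start_token
    result = []
    for _ in range(max_len):
        scores, state = decode_step(token, encoder_output, state)
        # Pick the highest-scoring token index (greedy choice).
        token = max(range(len(scores)), key=scores.__getitem__)
        if token == end_token:
            break
        result.append(token)
    return result

# Toy stand-ins (assumptions for illustration only): the "state" is a step
# counter and the "decoder" simply copies the source, then emits the end token 0.
def toy_encode(source):
    return 0, list(source)

def toy_decode_step(token, encoder_output, state):
    scores = [0.0] * 5
    target = encoder_output[state] if state < len(encoder_output) else 0
    scores[target] = 1.0
    return scores, state + 1

decoded = greedy_decode(toy_encode, toy_decode_step, [3, 1, 4], start_token=2, end_token=0)
```

With the Keras models from the answer, `encode` would presumably wrap `encoder_model.predict(...)` and `decode_step` would wrap `decoder_model.predict(...)`, with the inputs listed in the same order as they were passed to `tf.keras.Model`.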
Original page content provided by Stack Overflow. Translation support provided by the Tencent Cloud Xiaowei IT-domain engine.
Original link:

https://stackoverflow.com/questions/62968276
