首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras模型输入语法,使用加号(+)

Keras模型输入语法,使用加号(+)
EN

Stack Overflow用户
提问于 2018-04-20 08:46:06
回答 1查看 393关注 0票数 0

我的主要问题是在声明keras模型输入/outputs时使用"+“,这与普通的[input1, input2],[output1,output2]方法有什么不同?例如,在这个最小lstm seq2seq推理模型中:在对模型进行了培训之后,作者定义了推理模型:

代码语言:javascript
复制
decoder_model = Model(
[decoder_inputs] + decoder_states_inputs,
[decoder_outputs] + decoder_states)

我在keras文档中找不到这样的例子。

如果您想知道我的问题的具体内容:我正在编写一个用于特征提取的CNN,-> ->层(提供状态),-> GRU层,->稠密层架构,用于对图像执行OCR。我的原型训练了很好的,但是当我试图声明与上面的例子类似的推理模型时,添加不同维度的输入时会出现一个错误,但是上面的例子也有不同的维度。

以下是我的推理模型:

代码语言:javascript
复制
decoder_state_input = Input(shape=(deencoder_dims,))

decoder_outputs, state_h = decoder_gru(
    decoder_input, initial_state=decoder_state_input)

decoder_outputs = decoder_dense(decoder_outputs)

decoder_model = Model(
    [decoder_input] + decoder_state_input,
    [decoder_outputs] + state_h)

有以下输入/产出:

代码语言:javascript
复制
decoder_input = (None,83) (num of decoder tokens)
decoder_state_input = (None,100) (states)
decoder_outputs = (None,83) (tokens)
decoder_states = (None,100) (states)

这会导致错误:InvalidArgumentError: Dimensions must be equal, but are 83 and 100 for 'add_1' (op: 'Add') with input shapes: [1,?,?,83], [?,100].不确定1/ 1,?,?,83来自哪里.

这是示例中的代码:

代码语言:javascript
复制
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)

运行得很好。

我不明白为什么要这样声明输入,也不知道文档中有什么可以解释它。我知道当我尝试这样做的时候会弹出错误,因为输入是不同的尺寸,但是这个例子不会发生同样的情况!?它还具有不同大小的输入,下面是示例中对推理模型的总结:

代码语言:javascript
复制
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_12 (InputLayer)           (None, None, 83)     0                                            
__________________________________________________________________________________________________
input_13 (InputLayer)           (None, 100)          0                                            
__________________________________________________________________________________________________
input_14 (InputLayer)           (None, 100)          0                                            
__________________________________________________________________________________________________
lstm_4 (LSTM)                   [(None, None, 100),  73600       input_12[0][0]                   
                                                                 input_13[0][0]                   
                                                                 input_14[0][0]                   
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, None, 83)     8383        lstm_4[1][0]                     
==================================================================================================

谢谢你的洞察力

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-04-23 09:19:45

好的,基本上,"+“是将keras添加合并层应用到输入。这显然是向现有图添加新输入的唯一方法,这个问题给了我一个提示。第二个问题是,您不能添加不同维度的输入,但是可以通过将输入声明为list input1 = [Input(something)]来规避这一问题,在model = Model(Input = ...)上声明输入时也不能这样做,我不知道为什么,但它对我无效。

我的工作代码如下所示:

代码语言:javascript
复制
decoder_state_input = [Input(shape=(deencoder_dims,))]
decoder_outputs, state_h = decoder_gru(
    decoder_input, initial_state=decoder_state_input)
state_h_out = [state_h]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_input]+decoder_state_input,
    [decoder_outputs]+state_h_out
)
model.summary()

注意,输入被声明为列表。

最后:

  • 声明新的输入/输出附加到现有的图形:需要声明为列表并与以前的输入合并(不知道keras如何将它们分开.)
  • 角角并不总是有意义的:)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49937629

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档