首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用训练好的模型层在keras中创建另一个模型

使用训练好的模型层在keras中创建另一个模型
EN

Stack Overflow用户
提问于 2019-05-06 13:10:18
回答 1查看 694关注 0票数 2

我用Keras创建了一个模型,如下所示:

代码语言:javascript
复制
    m = Sequential()
    m.add(Dense(912, activation='relu', input_shape=(943, 1)))
    m.add(Dense(728, activation='relu'))
    m.add(Dense(528, activation='relu'))
    m.add(Flatten())
    m.add(Dense(500, activation='relu', name="bottleneck"))
    m.add(Dense(528, activation='relu'))
    m.add(Dense(728, activation='relu'))
    m.add(Dense(943, activation='linear'))

    m.compile(loss='mean_squared_error', optimizer='SGD')
    m.summary()

现在我想把bottleneck层添加到下面的创建网络中:

代码语言:javascript
复制
    model = Sequential()
    model.add(Dense(930, activation='relu', input_shape=(943, 1)))
    model.add(Flatten())
    model.add(m.get_layer('bottleneck'))
    model.add(m.get_layer('bottleneck'))
    model.add(m.get_layer('bottleneck'))
    model.add(m.get_layer('bottleneck'))
    model.add(Flatten())
    model.add(Dense(100, activation='linear'))

但是在训练模型m之后,在启动错误时抛出错误:

代码语言:javascript
复制
ValueError: Input 0 is incompatible with layer bottleneck: expected axis -1 of input shape to have value 497904 but got shape (None, 876990)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-05-06 15:41:28

错误消息试图告诉您,与第一个模型相比,第二个模型中“瓶颈”层的输入形状不同。

为了重用一个层,你需要匹配该层的输入数量。在您的示例中,第一个模型对该层有497904个输入,但您正尝试在下一个模型中使用它,该模型具有一个具有876990个输入的输入层。

我怀疑你想要更多这样的东西(注意,我在每种情况下都立即变平了,这样我们就可以更好地掌握每一层的输入数量):

代码语言:javascript
复制
m = Sequential()
m.add(Flatten(input_shape=(943, 1)))
m.add(Dense(912, activation='relu'))
m.add(Dense(728, activation='relu'))
m.add(Dense(528, activation='relu'))
m.add(Dense(500, activation='relu', name="bottleneck"))
m.add(Dense(528, activation='relu'))
m.add(Dense(728, activation='relu'))
m.add(Dense(943, activation='linear'))

m.compile(loss='mean_squared_error', optimizer='SGD')
m.summary()
代码语言:javascript
复制
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten (Flatten)            (None, 943)               0         
_________________________________________________________________
dense (Dense)                (None, 912)               860928    
_________________________________________________________________
dense_1 (Dense)              (None, 728)               664664    
_________________________________________________________________
dense_2 (Dense)              (None, 528)               384912    
_________________________________________________________________
bottleneck (Dense)           (None, 500)               264500    
_________________________________________________________________
dense_3 (Dense)              (None, 528)               264528    
_________________________________________________________________
dense_4 (Dense)              (None, 728)               385112    
_________________________________________________________________
dense_5 (Dense)              (None, 943)               687447    
=================================================================
Total params: 3,512,091
Trainable params: 3,512,091
Non-trainable params: 0

请注意,我们的瓶颈层的输入具有形状(None,528)。现在,在第二个模型中,我们可以这样做:

代码语言:javascript
复制
model = Sequential()
model.add(Dense(930, activation='relu', input_shape=(943, 1)))
model.add(Flatten())
model.add(Dense(528, activation='relu'))
model.add(m.get_layer('bottleneck'))
model.add(Flatten())
model.add(Dense(100, activation='linear'))
model.summary()
代码语言:javascript
复制
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              (None, 943, 930)          1860      
_________________________________________________________________
flatten_3 (Flatten)          (None, 876990)            0         
_________________________________________________________________
dense_10 (Dense)             (None, 528)               463051248 
_________________________________________________________________
bottleneck (Dense)           (None, 500)               264500    
_________________________________________________________________
flatten_4 (Flatten)          (None, 500)               0         
_________________________________________________________________
dense_11 (Dense)             (None, 100)               50100     
=================================================================
Total params: 463,367,708
Trainable params: 463,367,708
Non-trainable params: 0
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55999097

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档