首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在keras函数对象(例如InceptionResNetV2)中添加层

如何在keras函数对象(例如InceptionResNetV2)中添加层
EN

Stack Overflow用户
提问于 2022-09-14 21:07:31
回答 1查看 80关注 0票数 0

我试图在InceptionResNetV2 (或任何其他可以通过tf.keras.applications导入的预先培训的网络)中添加层。我知道可以将对象添加到顺序模型或函数模型中。但是,当我这样做时,我将无法访问来自各层的单个输出,以便在Grad或类似的应用程序中使用它们。

我现在正在使用下面的模型结构。它有效,它可以训练。但是,它不允许我访问关于特定输入和特定输出的InceptionResNetV2的最后一个卷积层的输出。

代码语言:javascript
复制
from tensorflow.keras import layers, models
InceptionResNetV2 = tf.keras.applications.inception_resnet_v2.InceptionResNetV2

def get_base():
    conv_base = InceptionResNetV2(weights=None, include_top=False, input_shape=(224, 224, 3))
    conv_base.trainable = False
    return(conv_base)


def get_model():
    base = get_base()

    inputs = tf.keras.Input(shape=(224, 224, 3))
    x = base(inputs, training=False)
    x = layers.Flatten()(x)
    x = layers.Dense(512, "relu")(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Dense(256, "relu")(x)
    x = layers.Dropout(0.25)(x)
    dims = layers.Dense(2, name="Valence_Arousal")(x)
    expression = layers.Dense(2, name="Emotion_Category")(x)


    model = models.Model(inputs=[inputs], outputs=[expression, dims])
    return(model)

print(get_model().summary())
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-09-15 04:16:58

在创建嵌套模型之后,扩展它们是很困难的。将input_tensor参数传递给预先训练的模型,可以得到预期的结果。

代码语言:javascript
复制
def get_model():

    inputs = tf.keras.Input(shape=(224, 224, 3))
    
    conv_base = InceptionResNetV2(weights=None, include_top=False, input_tensor = inputs)
    conv_base.trainable = False
    
    x = layers.Flatten()(conv_base.output)
    x = layers.Dense(512, "relu")(x)
    x = layers.Dropout(0.25)(x)
    x = layers.Dense(256, "relu")(x)
    x = layers.Dropout(0.25)(x)
    
    dims = layers.Dense(2, name="Valence_Arousal")(x)
    expression = layers.Dense(2, name="Emotion_Category")(x)


    model = models.Model(inputs=[inputs], outputs=[expression, dims])
    return(model)

示范摘要:

代码语言:javascript
复制
input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
conv2d (Conv2D)                (None, 111, 111, 32  864         ['input_1[0][0]']                
                                )  
...
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73723176

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档