首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在Keras中组合模型(输出)

在Keras中组合模型(输出)
EN

Stack Overflow用户
提问于 2021-11-09 14:58:33
回答 1查看 22关注 0票数 1

我正在尝试构建以下论文中介绍的网络:link

基本上,自动编码器是另外两个模型的组合,embedder和recovery如下所述:

代码语言:javascript
复制
X = Input(shape=[TIMESTEPS, FEAT], batch_size=BATCH_SIZE, name='RealData')

def recovery(self, H):

    L1 = LSTM(HIDDEN_NODES, return_sequences=True)(H)
    L2 = LSTM(HIDDEN_NODES, return_sequences=True)(L1)
    L3 = LSTM(HIDDEN_NODES, return_sequences=True)(L2)  
    O = Dense(OUTPUT_NODES, activation='sigmoid', name='OUTPUT')(L3)

    return O

def embedder(self, X):
    L1 = LSTM(HIDDEN_NODES, return_sequences=True)(X)
    L2 = LSTM(HIDDEN_NODES, return_sequences=True)(L1)
    L3 = LSTM(HIDDEN_NODES, return_sequences=True)(L2)      
    O = Dense(HIDDEN_NODES, activation='sigmoid')(L3)
    return O 

最后,将它们与以下几行结合起来:

代码语言:javascript
复制
    H = self.embedder(X) 

    X_tilde = self.recovery(H)

    self.autoencoder = Model(inputs=X, outputs=X_tilde)

显示自动编码器的.summary,我有以下内容:

并引发以下错误:

代码语言:javascript
复制
var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
AttributeError: 'function' object has no attribute 'trainable_variables'

我哪里做错了?

我正在重现的基线代码可以在here中找到

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-11-09 15:15:08

问题是embedderrecovery不是trainable_variables的模型。这两个函数只是返回最后一层的输出。也许可以试试这样的东西:

代码语言:javascript
复制
import tensorflow as tf

X = tf.keras.layers.Input(shape=[10, 10], batch_size=2, name='RealData')

def recovery():
    model = tf.keras.Sequential([
      tf.keras.layers.LSTM(10, return_sequences=True),
      tf.keras.layers.LSTM(10, return_sequences=True),
      tf.keras.layers.LSTM(10, return_sequences=True),
      tf.keras.layers.Dense(10, activation='sigmoid', name='OUTPUT')
    ])
    return model

def embedder():
    model = tf.keras.Sequential([
      tf.keras.layers.LSTM(10, return_sequences=True),
      tf.keras.layers.LSTM(10, return_sequences=True),
      tf.keras.layers.LSTM(10, return_sequences=True),
      tf.keras.layers.Dense(10, activation='sigmoid')
    ])
    return model 


embedder_model = embedder() 
H = embedder_model(X)

recovery_model = recovery() 
X_tilde = recovery_model(H)

autoencoder = tf.keras.Model(inputs=X, outputs=X_tilde)

var_list = embedder_model.trainable_variables + embedder_model.trainable_variables

tf.print(var_list[:2])
代码语言:javascript
复制
[[[0.343916416 0.310338378 0.34440577 ... 0.0633761585 0.0405358076 0.276733816]
 [0.245998859 0.197870493 0.0333348215 ... -0.136249736 0.271893084 -0.0605607331]
 [-0.290359527 0.240957797 0.117871583 ... 0.172593892 0.113803834 0.0506341457]
 ...
 [0.15672195 -0.161336392 -0.13484776 ... 0.306486845 -0.0707859397 0.245753765]
 [0.00567743182 0.181330919 0.206510961 ... 0.0141542256 0.205756843 -0.074064374]
 [0.299010575 -0.236641362 0.272176802 ... 0.0658480823 0.04648754 -0.342863292]], [[0.224076748 -0.112819761 -0.114276126 ... -0.190908 -0.282466382 -0.0711786151]
 [-0.0689174235 0.203702673 -0.248280779 ... -0.0145524191 0.202952 0.0797807127]
 [0.0919017 0.108805738 -0.124872617 ... 0.26839748 0.21041657 0.251440644]
 ...
 [-0.117122218 -0.0974424109 -0.17138055 ... 0.150875479 0.0454813093 0.0753096]
 [-0.115990438 -0.360190183 -0.0988362879 ... -0.0655761734 0.11425022 0.0291871373]
 [-0.00164104556 -0.0442082509 0.135109842 ... -0.182655513 -0.0121813752 0.0497299805]]]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69900375

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档