首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >‘in’对象在vae.fit()函数中不可订阅

‘in’对象在vae.fit()函数中不可订阅
EN

Stack Overflow用户
提问于 2022-03-28 18:58:10
回答 1查看 90关注 0票数 1

我正在开发一个VAE,使用的是:数据集

我使用了keras 教程代码,并开发了自己的编码器和解码器,问题是当我运行vae.fit()时,我得到了'int' object is not subscriptable。我做错了什么?

代码语言:javascript
复制
df = pd.read_csv('local path')
xtrain, xtest = train_test_split(df, test_size=0.2)

编码器:

代码语言:javascript
复制
def encoder(input_shape):
   inputs = keras.Input(shape=input_shape)
   x = layers.Dense(128, activation='relu')(inputs)
   x = layers.Dense(128, activation='relu')(x)
   z_mean = layers.Dense(2, name='z_mean')(x)
   z_log_var = layers.Dense(2, name='z_log_var')(x)
   z = Sampling()([z_mean, z_log_var])
   encoder = keras.Model(inputs, [z_mean, z_log_var, z], name='encoder')
   encoder.summary()
   return encoder

解码器:

代码语言:javascript
复制
def decoder(input_shape):
   inputs = keras.Input(shape=input_shape)
   x = layers.Dense(128, activation='relu')(inputs)
   x = layers.Dense(128, activation='relu')(x)
   outputs = layers.Dense(input_shape[0], activation='sigmoid')(x)
   decoder = keras.Model(inputs, outputs, name='decoder')
   decoder.summary()
   return decoder

VAE级:

代码语言:javascript
复制
class VAE(keras.Model):
def __init__(self, encoder, decoder, **kwargs):
    super(VAE, self).__init__(**kwargs)
    self.encoder = encoder
    self.decoder = decoder
    self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
    self.reconstruction_loss_tracker = keras.metrics.Mean(
        name="reconstruction_loss"
    )
    self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

@property
def metrics(self):
    return [
        self.total_loss_tracker,
        self.reconstruction_loss_tracker,
        self.kl_loss_tracker,
    ]

def train_step(self, data):
    with tf.GradientTape() as tape:
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss
    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    self.total_loss_tracker.update_state(total_loss)
    self.reconstruction_loss_tracker.update_state(reconstruction_loss)
    self.kl_loss_tracker.update_state(kl_loss)
    return {
        "loss": self.total_loss_tracker.result(),
        "reconstruction_loss": self.reconstruction_loss_tracker.result(),
        "kl_loss": self.kl_loss_tracker.result(),
    }

这就是我得到错误的地方:

代码语言:javascript
复制
data = np.concatenate([xtrain.values, xtest.values])

vae = VAE(encoder(data.shape[1]), 
decoder(data.shape[1]))
vae.compile(optimizer="adam", 
loss="binary_crossentropy")
vae.fit(data, epochs=10, batch_size=32, 
validation_split=0.2)

完整回溯:

代码语言:javascript
复制
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
c:\Users\User\Documents\Github\Generative-Models\TFG\VAE.ipynb Cell 9' in <cell line: 3>()
  1 data = np.concatenate([xtrain.values, xtest.values])
  ----> 3 vae = VAE(encoder(data.shape[1]), decoder(data.shape[1]))
  4 vae.compile(optimizer="adam", loss="binary_crossentropy")
  5 vae.fit(data, epochs=10, batch_size=32, validation_split=0.2)

  c:\Users\User\Documents\Github\Generative-Models\TFG\VAE.ipynb Cell 7' in 
  decoder(input_shape)
  3 x = layers.Dense(128, activation='relu')(inputs)
  4 x = layers.Dense(128, activation='relu')(x)
  ----> 5 outputs = layers.Dense(input_shape[0], activation='sigmoid')(x)
  6 decoder = keras.Model(inputs, outputs, name='decoder')
  7 decoder.summary()

  TypeError: 'int' object is not subscriptable

我该换什么?帮助是非常感谢的。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-28 19:03:08

encoderdecoder函数需要一个input_shape序列。但是有了

代码语言:javascript
复制
vae = VAE(
    encoder(data.shape[1]), 
    decoder(data.shape[1])
)

您正在传递int值。

您可以通过传递一个int值序列来修复这个问题。例如,用

代码语言:javascript
复制
vae = VAE(
    encoder(data.shape[1:]), 
    decoder(data.shape[1:])
)

这假设数据的形状是(samples, features)。那么你的input_shape将是(features,)

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71652404

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档