
Variational autoencoder giving negative loss

Data Science user
Asked on 2019-04-05 14:25:07
1 answer · 3.4K views · 0 following · 2 votes

I am learning about variational autoencoders and have implemented a simple example in Keras; the model summary is below. I copied the loss function from a blog post by Francois, and now my loss goes sharply negative. What am I missing here?

    Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 224)]        0
__________________________________________________________________________________________________
encoding_flatten (Flatten)      (None, 224)          0           input_1[0][0]
__________________________________________________________________________________________________
encoding_layer_2 (Dense)        (None, 256)          57600       encoding_flatten[0][0]
__________________________________________________________________________________________________
encoding_layer_3 (Dense)        (None, 128)          32896       encoding_layer_2[0][0]
__________________________________________________________________________________________________
encoding_layer_4 (Dense)        (None, 64)           8256        encoding_layer_3[0][0]
__________________________________________________________________________________________________
encoding_layer_5 (Dense)        (None, 32)           2080        encoding_layer_4[0][0]
__________________________________________________________________________________________________
encoding_layer_6 (Dense)        (None, 16)           528         encoding_layer_5[0][0]
__________________________________________________________________________________________________
encoder_mean (Dense)            (None, 16)           272         encoding_layer_6[0][0]
__________________________________________________________________________________________________
encoder_sigma (Dense)           (None, 16)           272         encoding_layer_6[0][0]
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 16)           0           encoder_mean[0][0]
                                                                 encoder_sigma[0][0]
__________________________________________________________________________________________________
decoder_layer_1 (Dense)         (None, 16)           272         lambda[0][0]
__________________________________________________________________________________________________
decoder_layer_2 (Dense)         (None, 32)           544         decoder_layer_1[0][0]
__________________________________________________________________________________________________
decoder_layer_3 (Dense)         (None, 64)           2112        decoder_layer_2[0][0]
__________________________________________________________________________________________________
decoder_layer_4 (Dense)         (None, 128)          8320        decoder_layer_3[0][0]
__________________________________________________________________________________________________
decoder_layer_5 (Dense)         (None, 256)          33024       decoder_layer_4[0][0]
__________________________________________________________________________________________________
decoder_mean (Dense)            (None, 224)          57568       decoder_layer_5[0][0]
==================================================================================================
Total params: 203,744
Trainable params: 203,744
Non-trainable params: 0
__________________________________________________________________________________________________
Train on 3974 samples, validate on 994 samples
Epoch 1/10
3974/3974 [==============================] - 3s 677us/sample - loss: -28.1519 - val_loss: -33.5864
Epoch 2/10
3974/3974 [==============================] - 1s 346us/sample - loss: -137258.8175 - val_loss: -3683802.1489
Epoch 3/10
3974/3974 [==============================] - 1s 344us/sample - loss: -14543022903.6056 - val_loss: -107811177469.9396
Epoch 4/10
3974/3974 [==============================] - 1s 363us/sample - loss: -3011718676570.7012 - val_loss: -13131454938476.6816
Epoch 5/10
3974/3974 [==============================] - 1s 350us/sample - loss: -101442605943572.4844 - val_loss: -322685056398605.9375
Epoch 6/10
3974/3974 [==============================] - 1s 344us/sample - loss: -1417424385529640.5000 - val_loss: -3687688508198145.5000
Epoch 7/10
3974/3974 [==============================] - 1s 358us/sample - loss: -11794297368126698.0000 - val_loss: -26632844827070784.0000
Epoch 8/10
3974/3974 [==============================] - 1s 339us/sample - loss: -69508229806130784.0000 - val_loss: -141312065640756336.0000
Epoch 9/10
3974/3974 [==============================] - 1s 345us/sample - loss: -319838384005810432.0000 - val_loss: -599553350073361152.0000
Epoch 10/10
3974/3974 [==============================] - 1s 342us/sample - loss: -1221653451351326464.0000 - val_loss: -2147128507956525312.0000

Latent sampling function:

def sampling(self,args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.
    # Arguments
        args (tensor): mean and log of variance of Q(z|X)
    # Returns
        z (tensor): sampled latent vector
    """

    z_mean, z_log_var = args
    set = tf.shape(z_mean)[0]
    batch = tf.shape(z_mean)[1]
    dim = tf.shape(z_mean)[-1]
    # by default, random_normal has mean=0 and std=1.0
    epsilon = tf.random.normal(shape=(set, dim))#tfp.distributions.Normal(mean=tf.zeros(shape=(batch, dim)),loc=tf.ones(shape=(batch, dim)))
    return z_mean + (z_log_var * epsilon)
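For comparison, the usual reparameterization treats the second encoder output as a log-variance and scales the noise by exp(0.5 * z_log_var); multiplying epsilon by z_log_var directly, as above, can flip the sign of the noise, since a log-variance is frequently negative. A minimal scalar sketch in plain Python (not the Keras model above):

```python
import math

def reparameterize(mu, log_var, eps):
    # Standard trick: sigma = exp(0.5 * log_var), z = mu + sigma * eps.
    # exp() guarantees sigma > 0 even when log_var is negative.
    return mu + math.exp(0.5 * log_var) * eps

# With log_var = 0 (unit variance), z = mu + eps.
z = reparameterize(1.0, 0.0, 2.0)  # 3.0
```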

Loss function:

def vae_loss(self,input, x_decoded_mean):
    xent_loss = tf.reduce_mean(tf.keras.backend.binary_crossentropy(input, x_decoded_mean))
    kl_loss = -0.5 * tf.reduce_sum(tf.square(self.encoded_mean) + tf.square(self.encoded_sigma) - tf.math.log(tf.square(self.encoded_sigma)) - 1, -1)
    return xent_loss + kl_loss

Another vae_loss implementation:

def vae_loss(self,input, x_decoded_mean):
    gen_loss = tf.reduce_sum(tf.keras.backend.binary_crossentropy(input, x_decoded_mean))
    #gen_loss = tf.losses.mean_squared_error(input,x_decoded_mean)
    kl_loss = -0.5 * tf.reduce_sum(1 + self.encoded_sigma - tf.square(self.encoded_mean) - tf.exp(self.encoded_sigma), -1)
    return tf.reduce_mean(gen_loss + kl_loss)

log_sigma kl_loss:

kl_loss = 0.5 * tf.reduce_sum(tf.square(self.encoded_mean) + tf.square(tf.exp(self.encoded_sigma)) - self.encoded_sigma - 1, axis=-1)
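For reference, the closed-form KL divergence between N(mu, sigma^2) and the unit Gaussian prior is always non-negative; the first vae_loss above negates exactly this expression, which lets the optimizer drive the loss toward minus infinity. A scalar sketch of the correct sign in plain Python:

```python
import math

def kl_divergence(mu, sigma):
    # KL( N(mu, sigma^2) || N(0, 1) ) = 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)
    # Always >= 0, and exactly 0 only when mu = 0 and sigma = 1.
    return 0.5 * (mu ** 2 + sigma ** 2 - math.log(sigma ** 2) - 1.0)

print(kl_divergence(0.0, 1.0))  # 0.0: posterior equals the prior
print(kl_divergence(1.0, 1.0))  # 0.5
```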

1 Answer

Data Science user
Answered on 2019-08-30 12:52:22

Is your data binary? The binary_crossentropy loss is meant for binary inputs; all of the data in MNIST is binary. If your input is continuous, such as color images, try MSE loss or another reconstruction loss. Have a look here: https://github.com/Lasagne/Recipes/issues/54
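The answer's point can be seen from the elementwise definition of binary cross-entropy: for targets inside [0, 1] it is always non-negative, but targets outside that range (e.g. unnormalized continuous data) can make it negative, matching the diverging negative loss in the question. A sketch in plain Python:

```python
import math

def binary_crossentropy(t, p):
    # Elementwise BCE: -(t * log(p) + (1 - t) * log(1 - p)).
    return -(t * math.log(p) + (1 - t) * math.log(1 - p))

print(binary_crossentropy(1.0, 0.9))  # positive, as expected for t in [0, 1]
print(binary_crossentropy(2.0, 0.9))  # negative: target outside [0, 1]
```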

0 votes
Original content provided by Data Science Stack Exchange.
Original link: https://datascience.stackexchange.com/questions/48697
