首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >生成式对抗网络中判别器损失不变

生成式对抗网络中判别器损失不变
EN

Stack Overflow用户
提问于 2020-11-07 01:23:47
回答 1查看 711关注 0票数 0

我正在尝试用pix2pix GAN生成器和Unet作为鉴别器来训练GAN。但经过一段时间后,我的鉴别器损失停止变化,并停留在5.546左右的值。这是GAN训练的好兆头还是坏兆头。

这是我的损失计算:

代码语言:javascript
复制
def discLoss(rValid, rLabel, fValid, fLabel):
    # validity loss
    bce =     tf.keras.losses.BinaryCrossentropy(from_logits=True,label_smoothing=0.1)
    # classifier loss
    scce =     tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # Loss for real
    real_dloss = (bce(tf.ones_like(rValid), rValid) + scce(label, rLabel))#/2
    # Loss for fake
    fake_dloss = (bce(tf.zeros_like(fValid), fValid) + scce(label, fLabel))#/2
    # Total discriminator loss
    d_loss = (real_dloss + fake_dloss)# / 2
    return d_loss

def generator_loss(disc_generated_output, gen_output, target):
  loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
  LAMBDA = 100
  # mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss

这是我的训练步骤:

代码语言:javascript
复制
def train_step(img1, img2, label, generator,discriminator,generator_optimizer,discriminator_optimizer):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    fImg = generator([img1, label], training=True)
    rValid, rLabel = discriminator(img2, training=True)
    fValid, fLabel = discriminator(fImg, training=True)

    disc_loss = discLoss(rValid, rLabel, fValid, fLabel)
    gen_loss = generator_loss(fValid, fImg, img2)
    # genLoss(label, rValid, rLabel, fValid, fLabel)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    return tf.math.reduce_mean(gen_loss).numpy(), disc_loss.numpy()
EN

回答 1

Stack Overflow用户

发布于 2020-11-08 04:50:27

这个损失太高了。你需要注意G和D的学习速度是一致的。请访问此问题并提供相关链接:How to balance the generator and the discriminator performances in a GAN?

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64719046

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档