我正在尝试用 pix2pix 的生成器和一个 U-Net 判别器来训练 GAN。但训练一段时间后,我的判别器损失不再变化,一直停留在 5.546 左右。这对 GAN 训练来说是好兆头还是坏兆头?
这是我的损失计算:
def discLoss(rValid, rLabel, fValid, fLabel, label=None):
    """Return the auxiliary-classifier discriminator loss.

    The discriminator has two heads per image: a validity logit and class
    logits. The loss is the *unaveraged* sum of four terms::

        BCE(real -> 1) + SCCE(label, rLabel) + BCE(fake -> 0) + SCCE(label, fLabel)

    Args:
        rValid: validity logits for the real images.
        rLabel: class logits for the real images.
        fValid: validity logits for the generated images.
        fLabel: class logits for the generated images.
        label: integer class ids for the batch. These are used as the class
            target for both real and generated images. The earlier version
            read an undefined/global ``label``, so pass it explicitly.

    Returns:
        Scalar loss tensor.

    Raises:
        NameError: if ``label`` is omitted and no module-level ``label`` exists.
    """
    if label is None:
        # Backward compatibility: 4-argument callers previously fell through to
        # a module-level `label`. Keep that lookup but fail with a clear message.
        try:
            label = globals()["label"]
        except KeyError:
            raise NameError(
                "discLoss() needs the batch class labels; pass label=..."
            ) from None

    # Validity head; inputs are raw logits (from_logits=True).
    # label_smoothing=0.1 squeezes the targets to 0.95 / 0.05, so each validity
    # term has a nonzero floor (about 0.2 nats) even for a perfect discriminator.
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)
    # Class head; expects integer class ids (not one-hot) and raw logits.
    scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    real_dloss = bce(tf.ones_like(rValid), rValid) + scce(label, rLabel)
    fake_dloss = bce(tf.zeros_like(fValid), fValid) + scce(label, fLabel)
    # Deliberately not averaged (the earlier `/ 2` divisions were commented
    # out). The value is therefore on a larger scale than a plain BCE D-loss.
    return real_dloss + fake_dloss


def generator_loss(disc_generated_output, gen_output, target, l1_weight=100):
    """Return the pix2pix-style generator loss: adversarial BCE + weighted L1.

    Args:
        disc_generated_output: discriminator validity logits for generated images.
        gen_output: generated images.
        target: ground-truth images the generator should reproduce.
        l1_weight: weight of the L1 reconstruction term (default 100, the
            previously hard-coded value).

    Returns:
        Scalar loss tensor.
    """
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # Adversarial term: reward generated images the discriminator scores as real.
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # Mean absolute error between the generated and target images.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    return gan_loss + l1_weight * l1_loss


# Training step:
def train_step(img1, img2, label, generator, discriminator,
               generator_optimizer, discriminator_optimizer):
    """Run one simultaneous generator / discriminator update.

    Args:
        img1: generator input images.
        img2: real images; scored by the discriminator and used as the L1 target.
        label: integer class ids. They are fed to the generator and used as the
            class target in the discriminator loss.
        generator: model called as ``generator([img1, label])``.
        discriminator: model returning ``(validity_logits, class_logits)``.
        generator_optimizer: optimizer applied to the generator's variables.
        discriminator_optimizer: optimizer applied to the discriminator's variables.

    Returns:
        ``(generator_loss, discriminator_loss)`` as NumPy values. The
        ``.numpy()`` calls mean this must run eagerly, not inside ``tf.function``.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fImg = generator([img1, label], training=True)
        rValid, rLabel = discriminator(img2, training=True)
        fValid, fLabel = discriminator(fImg, training=True)

        # Pass this batch's labels explicitly. Previously discLoss resolved
        # `label` to an undefined or global name, not this argument.
        disc_loss = discLoss(rValid, rLabel, fValid, fLabel, label=label)
        # NOTE(review): fLabel is not used in the generator loss, so G gets no
        # signal from the discriminator's class head. Confirm this is intended.
        gen_loss = generator_loss(fValid, fImg, img2)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    # Both losses are already scalars, so no extra reduction is needed.
    return gen_loss.numpy(), disc_loss.numpy()
这个损失值太高了。你需要确保 G 和 D 的学习速度保持一致(即两者相互平衡)。请参阅下面这个问题及其中的相关链接:How to balance the generator and the discriminator performances in a GAN?
https://stackoverflow.com/questions/64719046
复制相似问题