首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在GAN中,有必要编译生成器吗

在GAN中,有必要编译生成器吗
EN

Stack Overflow用户
提问于 2020-06-10 20:43:15
回答 1查看 341关注 0票数 1

我一直在研究GAN,让我抓狂的是,为什么我们要编译生成器模型,即使我们编译了组合的GAN模型,为什么还要单独编译生成器。

代码语言:javascript
复制
def create_generator():
    generator = Sequential()

    generator.add(Dense(256, input_dim=noise_dim))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(1024))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(img_rows*img_cols*channels, activation='tanh'))

    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator

def create_descriminator():
    discriminator = Sequential()

    discriminator.add(Dense(1024, input_dim=img_rows*img_cols*channels))
    discriminator.add(LeakyReLU(0.2))

    discriminator.add(Dense(512))
    discriminator.add(LeakyReLU(0.2))

    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))

    discriminator.add(Dense(1, activation='sigmoid'))

    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator

discriminator = create_descriminator()
generator = create_generator()

# Make the discriminator untrainable when we are training the generator.  This doesn't effect the discriminator by itself
discriminator.trainable = False

# Link the two models to create the GAN
gan_input = Input(shape=(noise_dim,))
fake_image = generator(gan_input)

gan_output = discriminator(fake_image)

gan = Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)

正如在这段代码中可以看到的那样,生成器、鉴别器和gan (组合模型)都是经过编译的。根据我的理解,我们应该只编译鉴别器(用于训练鉴别器)和gan (组合模型,用于训练生成器),因为鉴别器权重在GAN训练期间被冻结,因此只有生成器得到训练。那么为什么要编译生成器呢?

EN

回答 1

Stack Overflow用户

发布于 2020-06-12 22:19:40

在训练过程中,generatordiscriminator有着截然相反的目标:discriminator试图区分假图像和真实图像,而生成器则试图产生看起来足够真实的图像,以欺骗鉴别器。

由于GAN由具有不同目标的两个网络组成,因此它不能像常规神经网络那样进行训练。每个训练迭代分为两个阶段:

  • 在第一阶段,我们训练鉴别器。从训练集中采样一批真实图像,并使用由生成器生成的相同数量的假图像来完成。对于假图像,标签被设置为0,对于真实图像,标签被设置为1,并且使用二进制交叉熵损失在该标签批次上训练鉴别器一步。重要的是,反向传播在这一阶段只优化鉴别器的权重。
  • 在第二阶段,我们训练生成器。我们首先使用它来产生另一批假图像,然后再次使用鉴别器来判断图像是假的还是真的。这一次,我们没有在批中添加真实的图像,并且所有标签都设置为1(真实):换句话说,我们希望生成器生成鉴别器会(错误地)认为是真实的图像!至关重要的是,在这一步中,discriminator的权重是frozen的,因此反向传播只影响生成器的权重。

接下来,我们需要编译这些模型。generator只会通过gan model来训练,所以我们根本不需要编译。重要的是,discriminator 不应在第二阶段期间进行训练,因此我们将其设为non-trainable,然后再对gan模型进行compiling

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62303892

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档