首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在GAN实现中正确设置.trainable变量

在GAN实现中正确设置.trainable变量
EN

Stack Overflow用户
提问于 2019-11-11 15:04:09
回答 1查看 338关注 0票数 2

我对在实现GAN中的.trainable语句tf.keras.model感到困惑。

给定以下代码片段(摘自这个回购):

代码语言:javascript
复制
class GAN():

    def __init__(self):

        ...

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        ...

        return Model(noise, img)

    def build_discriminator(self):

        ...

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

在定义模型self.combined时,鉴别器的权重被设置为self.discriminator.trainable = False,但从未打开。

然而,在训练循环期间,鉴别器的权重将因线路而发生变化:

代码语言:javascript
复制
# Train the discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

并将在下列期间保持不变:

代码语言:javascript
复制
# Train the generator (to have the discriminator label samples as valid)
g_loss = self.combined.train_on_batch(noise, valid)

我没想到的。

当然,这是训练GAN的正确(迭代)方法,但我不明白为什么我们不需要通过self.discriminator.trainable = True才能对鉴别器进行一些培训。

如果有人对此有解释的话,那就太好了,我想这是一个需要理解的关键点。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-11-11 17:16:25

当您对github回购中的代码有疑问时,检查这些问题(包括打开的和关闭的)通常是个好主意。这个问题解释了为什么将标志设置为False。上面写着,

由于self.discriminator.trainable = False是在编译鉴别器之后设置的,因此它不会影响鉴别器的训练。但是,由于它是在编译组合模型之前设置的,所以在对组合模型进行训练时,将冻结鉴别器层。

同时也谈到了冻结角化层

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58803868

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档