首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >生成式对抗网络中如何利用鉴别器的输出来训练生成器

生成式对抗网络中如何利用鉴别器的输出来训练生成器
EN

Stack Overflow用户
提问于 2017-06-24 03:45:46
回答 1查看 830关注 0票数 0

最近我了解了Generative Adversarial Networks

为了训练生成器,我不知怎么搞错了它是如何学习的。Here是GAN的实现:

代码语言:javascript
复制
`# train generator
            z = Variable(xp.random.uniform(-1, 1, (batchsize, nz), dtype=np.float32))
            x = gen(z)
            yl = dis(x)
            L_gen = F.softmax_cross_entropy(yl, Variable(xp.zeros(batchsize, dtype=np.int32)))
            L_dis = F.softmax_cross_entropy(yl, Variable(xp.ones(batchsize, dtype=np.int32)))

        # train discriminator

        x2 = Variable(cuda.to_gpu(x2))
        yl2 = dis(x2)
        L_dis += F.softmax_cross_entropy(yl2, Variable(xp.zeros(batchsize, dtype=np.int32)))

        #print "forward done"

        o_gen.zero_grads()
        L_gen.backward()
        o_gen.update()

        o_dis.zero_grads()
        L_dis.backward()
        o_dis.update()`

因此,它计算了在论文中提到的发电机的损失。但是,它根据判别器输出调用Generator反向函数。鉴别器输出只是一个数字(不是数组)。

例如,如果输出是64*64,那么我们将其与64*64图像进行比较,然后计算损失并进行反向传播。

我认为如果我们连接两个网络(连接生成器和鉴别器),然后调用反向传播,但只需更新生成器参数,这是有意义的,它应该可以工作。但我在代码中看到的是完全不同的。

所以我在问这是怎么可能的?

谢谢

EN

回答 1

Stack Overflow用户

发布于 2018-12-25 18:09:22

你说‘但是,它根据鉴别器输出调用Generator反向函数。鉴别器输出只是一个数字(不是数组)’,而损失始终是标量值。当我们计算两个图像的均方误差时,它也是一个标量值。

L_adversarial =Elog(D(X))+Elog(1−D(G(Z)

X来自真实的数据分布

Z是由生成器转换的潜在数据分布

回到你的实际问题,鉴别器网络在最后一层有一个sigmoid激活函数,这意味着它的输出范围是0,1。鉴别器试图通过最大化添加到损失函数中的两个项来最大化这种损失。第一项的最大值为0,当D(x)为1时,第二项的最大值也为0时;当1-D(G(z))为1时,表示D(G(z))为0。所以判别器试图做一个二进制分类,最大化这个损失函数,当它被馈送x(真实数据)时,它试图输出1,当它被馈送G(Z)(生成的假数据)时,它试图输出0。但是Generator试图最小化这种损失,换句话说,它试图通过生成与真实样本相似的假样本来欺骗鉴别器。随着时间的推移,生成器和鉴别器都变得越来越好。这就是GAN背后的直觉。

代码在pytorch中

代码语言:javascript
复制
bce_loss = nn.BCELoss() #bce_loss = -ylog(y_hat)-(1-y)log(1-y_hat)[similar to L_adversarial]

Discriminator = ..... #some network   
Generator = ..... #some network

optimizer_generator = ....... #some optimizer for generator network    
optimizer_discriminator = ....... #some optimizer for discriminator network       

z = ...... #some latent data distribution that is transformed by the generator
real = ..... #real data distribution

#####################
#Update Discriminator
#####################
fake = Generator(z)
fake_prediction = Discriminator(fake)
real_prediction = Discriminator(real)
discriminator_loss = bce_loss(fake_prediction,torch.zeros(batch_size))+bce_loss(real_prediction,torch.ones(batch_size))
discriminator_loss.backward()
optimizer_discriminator.step()

#################
#Update Generator
#################
fake = Generator(z)
fake_prediction = Discriminator(fake)
generator_loss = bce_loss(fake_prediction,torch.ones(batch_size))
generator_loss.backward()
optimizer_generator.step()
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/44728913

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档