首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch中生成对抗网络(GAN)的训练生成器

PyTorch中生成对抗网络(GAN)的训练生成器
EN

Stack Overflow用户
提问于 2020-06-06 23:45:27
回答 1查看 418关注 0票数 1

我正致力于在PyTorch 1.5.0中实现一个生成对抗网络(GAN)。

为了计算生成器的损失,我计算了鉴别器错误分类全真小批次和全(生成器生成的)假小批次的负概率。然后,我按顺序向后传播这两部分,最后应用阶跃函数。

计算和反向传播作为所生成的假数据的错误分类的函数的部分损失似乎是直接的,因为在该损失项的反向传播期间,反向路径通过首先产生假数据的生成器。

然而,所有真实数据小批量的分类并不涉及通过生成器传递数据。因此,我想知道下面的代码片段是否仍然会为生成器计算梯度,或者它是否根本不会计算任何梯度(因为后向路径不会通过生成器,并且在更新生成器时鉴别器处于eval模式)?

代码语言:javascript
复制
# Update generator #
net.generator.train()
net.discriminator.eval()
net.generator.zero_grad()

# All-real minibatch
x_real = get_all_real_minibatch()
y_true = torch.full((batch_size,), label_fake).long()  # Pretend true targets were fake
y_pred = net.discriminator(x_real)  # Produces softmax probability distribution over (0=label_fake,1=label_real)

loss_real = NLLLoss(torch.log(y_pred), y_true) 
loss_real.backward()
optimizer_generator.step()

如果这不能像预期的那样工作,我怎么能让它工作呢?提前感谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-07 22:16:18

不会将梯度传播到生成器,因为没有使用生成器的任何参数执行任何计算。处于eval模式的鉴别器不会阻止梯度传播到生成器,尽管如果您使用的层在eval模式下与训练模式下的行为不同,例如dropout,则它们会略有不同。

真实图像的错误分类不是训练生成器的一部分,因为它不会从这些信息中获得任何东西。从概念上讲,生成器应该从鉴别器未能正确分类真实图像的事实中学到什么?生成器的唯一任务是创建一个假图像,以便鉴别器认为它是真实的,因此,生成器唯一相关的信息是鉴别器是否能够识别假图像。如果鉴别器确实能够识别假图像,则生成器需要调整自身以创建更有说服力的假图像。

当然,这不是二进制情况,但生成器总是试图改善假图像,以便鉴别器更确信它是真实图像。生成器的目标不是让鉴别器变得可疑(真假概率为0.5 ),而是让鉴别器完全相信它是真的,即使它是假的。这就是为什么他们是敌对的,而不是合作的。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62234151

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档