首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从GAN训练发电机?

如何从GAN训练发电机?
EN

Stack Overflow用户
提问于 2019-05-11 16:44:33
回答 2查看 2.1K关注 0票数 3

在阅读了GAN教程和代码示例之后,我仍然不明白生成器是如何被训练的。假设我们有一个简单的例子:-生成器输入是噪声,输出是灰度图像10x10 -鉴别器输入是图像10x10,输出是0到1(假的或真的)。

训练鉴别器很容易以输出为假,并期望0。我们用的是真正的输出尺寸--单值。

但是训练生成器是不同的--我们取假输出(1值),并将预期输出作为一个输出。但这听起来更像是又一次培训了犯罪嫌疑人。生成器的输出是图像10x10,如何用一个单一的值来训练它?在这种情况下,反向传播是如何工作的?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-05-11 17:05:10

要训练生成器,必须在冻结鉴别器的权重的同时,反向传播整个组合模型,以便只更新生成器。

为此,我们必须计算d(g(z; θg); θd),其中θg和θd是生成器和判别器的权重。为了更新生成器,我们可以计算梯度wrt。改为∂loss(d(g(z; θg); θd)) / ∂θg,然后使用正常梯度下降更新θg。

在Keras中,这可能类似于(使用functional ):

代码语言:javascript
复制
genInput = Input(input_shape)
discriminator = ...
generator = ...

discriminator.trainable = True
discriminator.compile(...)

discriminator.trainable = False
combined = Model(genInput, discriminator(generator(genInput)))
combined.compile(...)

通过将trainable设置为False,已经编译的模型不会受到影响,只有将来编译的模型才会被冻结。因此,判别器可以作为一个独立的模型进行训练,但是冻结在组合模型中。

然后,训练你的甘:

代码语言:javascript
复制
X_real = ...
noise = ...
X_gen = generator.predict(noise)

# This will only train the discriminator
loss_real = discriminator.train_on_batch(X_real, one_out)
loss_fake = discriminator.train_on_batch(X_gen, zero_out)

d_loss = 0.5 * np.add(loss_real, loss_fake)

noise = ...
# This will only train the generator.
g_loss = self.combined.train_on_batch(noise, one_out)
票数 3
EN

Stack Overflow用户

发布于 2019-06-04 10:16:56

我想理解发电机培训过程的最好方法是修改所有的训练循环。

每一个时代:

  1. 更新鉴别器:
代码语言:javascript
复制
- forward real images mini-batch pass through the Discriminator;
- compute the Discriminator loss and calculate gradients for the backward pass;
- generate fake images mini-batch via the Generator;
- forward generated fake mini-batch pass through the Discriminator;
- compute the Discriminator loss and derive gradients for the backward pass;
- add (real mini-batch gradients, fake mini-batch gradients)
- update the Discriminator (use Adam or SGD).

  1. 更新生成器:
代码语言:javascript
复制
- flip the targets: fake images get labeled as real for the Generator. Note: this step ensures using cross-entropy minimization for the Generator. It helps overcome the problem of Generator's vanishing gradients if we continue implementation of the GAN minmax game.
- forward fake images mini-batch pass through the updated Discriminator;
- compute Generator loss based on the updated Discriminator output, e.g.:

损失函数(利用判别器估计假图像的概率,1)。

注意:这里1表示假图像的生成器标签是真实的。

代码语言:javascript
复制
- update the Generator (use Adam or SGD)

我希望这能帮到你。从训练过程中你可以看到,GAN玩家在某种程度上是“合作的”,即判别器估计数据与模型分布密度的比率,然后与生成器自由地分享这些信息。从这个角度来看,判别器更像是教师指导生成器如何提高而不是对手“(引用自I.Goodfellow教程)。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56092361

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档