首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Tensorflow.js中训练卷积生成对抗网络?

如何在Tensorflow.js中训练卷积生成对抗网络?
EN

Stack Overflow用户
提问于 2020-04-29 08:38:23
回答 1查看 89关注 0票数 1

我在https://www.tensorflow.org/tutorials/generative/dcgan上学习教程。

虽然本教程是用python编写的,但我尝试在node.js上使用tensorflow.js来实现它。

我已经知道如何翻译大多数使用的方法,除了实际设置以下训练步骤的时候。

代码语言:javascript
复制
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

显然,并不是这里的所有东西都可以翻译成tensorflow.js。

到目前为止,我还不知道如何获得梯度并将它们应用于优化器。

我曾尝试使用tf.gradtf.grads函数,但无济于事。

这是我到目前为止所知道的:

代码语言:javascript
复制
function trainStep(images) {
    const noise = tf.randomNormal([BATCH_SIZE, noiseDim]);

    const generated = gen.apply(noise, { training: true });
    const realOut = dis.apply(images, { training: true });
    const genOut = dis.apply(generated, { training: true });

    const genLoss = generator.loss(genOut);
    const disLoss = discriminator.loss(realOut, genOut);

    // now what?
}

在tensorflow.js中,有没有比指南更好的方法呢?

如果任何人有任何资源为我指明正确的方向,我将不胜感激。

EN

回答 1

Stack Overflow用户

发布于 2020-05-01 04:24:14

试试TensorFlow.js的官方codelab:

https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html

这是针对MNIST的,但是一旦您了解了这一点,您就可以将其应用于您自己的数据集。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61492321

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档