首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >DCGAN产生噪音但不是预期的结构,有人能帮我从DCGAN中得到正确的结果吗?

DCGAN产生噪音但不是预期的结构,有人能帮我从DCGAN中得到正确的结果吗?
EN

Stack Overflow用户
提问于 2021-10-30 02:53:02
回答 1查看 71关注 0票数 1

有人能帮我从我的DCGAN中获得正确的结果/图像吗?

从一个迭代到另一个迭代,我得到了不同的彩色(噪声)图片,但没有接近我应该得到的。我用标签喂人脸/猫/狗来训练我的生成器和鉴别器,我应该得到如下输出

要么像猫,要么像狗,要么像人脸。

对于不同的潜在向量值,我得到了不同的结果/像素,对于不同的迭代,我得到了不同的颜色,但不是在面部或动物的结构中。

我使用二进制交叉熵损失函数作为生成器,鉴别器和gan。我尝试在生成器中使用交叉熵之外的MAE,但没有得到任何不同的结果。我试着一起训练生成器和鉴别器,分别训练生成器和鉴别器,并交替使用生成器和鉴别器,但这些尝试也没有什么好结果。我运行了700多个时期(5天),每次迭代在CPU上花费超过5分钟。

代码语言:javascript
复制
def generator(latent_dim, n_classes):
    
    initializer = tf.random_normal_initializer(0., 0.021)
    in_label = Input(shape=(1,))
    li = Embedding(n_classes, 25)(in_label)
    n_nodes = 64 * 64
    
    li = Dense(n_nodes)(li)
    
    li = Reshape((64, 64, 1))(li)
    
    
    in_lat = Input(shape=(latent_dim,))
    n_nodes = 128 * 64 * 64
    
    gen = Dense(n_nodes)(in_lat)
    
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = BatchNormalization(axis=-1)(gen)
    gen = Reshape((64, 64, 128))(gen)

    merge = Concatenate()([gen, li])
    gen = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same',kernel_initializer=initializer,use_bias=False)(merge)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = BatchNormalization(axis=-1)(gen)
    

    gen = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same',kernel_initializer=initializer,use_bias=False)(gen)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = BatchNormalization(axis=-1)(gen)

    
    out_layer = Conv2D(3, (7,7), activation='tanh', padding='same',kernel_initializer=initializer,use_bias=False)(gen)
    model = Model([in_lat, in_label], out_layer, name="generator")
    
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(optimizer=opt,loss='binary_crossentropy', metrics='accuracy')
    
    return model



def define_discriminator(n_classes,in_shape=(256,256,3)):
    
    initializer = tf.random_normal_initializer(0., 0.021)
    in_label = Input(shape=(1,))
    li = Embedding(n_classes, 25)(in_label)     
    n_nodes = in_shape[0] * in_shape[1]
    
    li = Dense(n_nodes)(li)

    li = Reshape((in_shape[0], in_shape[1],1))(li)
    
    in_image = Input(shape=in_shape)
    merge = Concatenate()([in_image, li])
    
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same',kernel_initializer=initializer,use_bias=False)(merge)
    fe = LeakyReLU(alpha=0.2)(fe)
    fe = BatchNormalization(axis=-1)(fe)

    fe = Conv2D(128, (3,3), strides=(2,2), padding='same',kernel_initializer=initializer,use_bias=False)(fe)
    fe = LeakyReLU(alpha=0.2)(fe)
    fe = BatchNormalization(axis=-1)(fe)
    
    
    fe = Flatten()(fe)
    fe = Dropout(0.4)(fe)

    out_layer = Dense(1, activation='sigmoid')(fe)
    model = Model([in_image, in_label], out_layer, name="discriminator")

    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(optimizer=opt, loss='binary_crossentropy', metrics=['accuracy'])
    
    return model



def define_gan(g_model, d_model):
    
    g_model.trainable = True
    d_model.trainable = False
    
    gen_noise, gen_label = g_model.input
    gen_output = g_model.output
    
    gan_output = d_model([gen_output, gen_label])
    model = Model([gen_noise, gen_label], gan_output, name="gan")
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(optimizer=opt,  loss='binary_crossentropy', metrics='accuracy')
    return model

下面是不同迭代的结果/图像。16个方块是16个不同的潜在向量种子。

EN

回答 1

Stack Overflow用户

发布于 2021-10-30 03:11:27

在没有看到任何代码的情况下,很难推测,但您可以查看this repo以供参考。

请注意以下注意事项:

另外,由于我们不知道的原因,如果你只在CPU上训练,这个笔记本中的鉴别器几乎总是无法学习。由于这个故障,GAN将很少学习如何生成草图--也就是说,它将输出仅仅是随机噪声的图像。我们已经确定了两种方法来补救这种情况:

  1. 使用图形处理器。如果你没有,请按照上面的建议使用Colab。在Colab中,您可以从菜单栏的“运行时”项中选择“更改运行时类型”,然后选择"GPU“作为您的硬件加速器。这个硬件加速器训练GAN数量级的速度比“无”或"TPU“选项和鉴别器更快(我们不知道为什么!)将训练鉴别器的优化器properly.
  2. Change。正如本笔记本的鉴别器编译步骤中的注释所指出的,从默认的RMSprop优化器切换到另一个(例如,Adam或AdaDelta)使鉴别器能够有效地学习,因此GAN生成草图。无论您仅使用CPU、GPU还是TPU,此解决方案都是有效的。(也就是说,用图形处理器训练GAN仍然比只用中央处理器或TPU快得多。)

您也可以查看此讨论:https://github.com/the-deep-learners/deep-learning-illustrated/issues/2

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69776434

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档