首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >调试GAN收敛误差

调试GAN收敛误差
EN

Stack Overflow用户
提问于 2018-06-08 14:06:49
回答 1查看 484关注 0票数 1

建造一个GAN来生成图像。图像有3个彩色通道,96x96。

在一开始由生成器生成的图像都是黑色的,这是一个问题,因为这在统计上是非常不可能的。

同时,这两个网络的损失也没有改善。

我在下面发布了整个代码,并对其进行了评论,使其易于阅读。这是我第一次建立一个根,我是新手,所以任何帮助都是非常感谢的!

谢谢。

代码语言:javascript
复制
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.autograd import Variable

import numpy as np
import os
import cv2
from collections import deque

# training params
batch_size = 100
epochs = 1000

# loss function
loss_fx = torch.nn.BCELoss()

# processing images
X = deque()
for img in os.listdir('pokemon_images'):
    if img.endswith('.png'):
        pokemon_image = cv2.imread(r'./pokemon_images/{}'.format(img))
        if pokemon_image.shape != (96, 96, 3):
            pass
        else:
            X.append(pokemon_image)

# data loader for processing in batches
data_loader = DataLoader(X, batch_size=batch_size)

# covert output vectors to images if flag is true, else input images to vectors
def images_to_vectors(data, reverse=False):
    if reverse:
        return data.view(data.size(0), 3, 96, 96)
    else:
        return data.view(data.size(0), 27648)

# Generator model
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        n_features = 1000
        n_out = 27648

        self.model = torch.nn.Sequential(
                torch.nn.Linear(n_features, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 512),
                torch.nn.ReLU(),
                torch.nn.Linear(512, 1024),
                torch.nn.ReLU(),
                torch.nn.Linear(1024, n_out),
                torch.nn.Tanh()
        )


    def forward(self, x):
        img = self.model(x)
        return img

    def noise(self, s):
       x = Variable(torch.randn(s, 1000))
       return x


# Discriminator model
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        n_features = 27648
        n_out = 1

        self.model = torch.nn.Sequential(
                torch.nn.Linear(n_features, 512),
                torch.nn.ReLU(),
                torch.nn.Linear(512, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, n_out),
                torch.nn.Sigmoid()
        )


    def forward(self, img):
        output = self.model(img)
        return output


# discriminator training
def train_discriminator(discriminator, optimizer, real_data, fake_data):
    N = real_data.size(0)
    optimizer.zero_grad()

    # train on real
    # get prediction
    pred_real = discriminator(real_data)

    # calculate loss
    error_real = loss_fx(pred_real, Variable(torch.ones(N, 1)))

    # calculate gradients
    error_real.backward()

    # train on fake
    # get prediction
    pred_fake = discriminator(fake_data)

    # calculate loss
    error_fake = loss_fx(pred_fake, Variable(torch.ones(N, 0)))

    # calculate gradients
    error_fake.backward()

    # update weights
    optimizer.step()

    return error_real + error_fake, pred_real, pred_fake


# generator training
def train_generator(generator, optimizer, fake_data):
    N = fake_data.size(0)

    # zero gradients
    optimizer.zero_grad()

    # get prediction
    pred = discriminator(generator(fake_data))

    # get loss
    error = loss_fx(pred, Variable(torch.ones(N, 0)))

    # compute gradients
    error.backward()

    # update weights
    optimizer.step()

    return error


# Instance of generator and discriminator
generator = Generator()
discriminator = Discriminator()

# optimizers
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)

# training loop
for epoch in range(epochs):
     for n_batch, batch in enumerate(data_loader, 0):
         N = batch.size(0)

         # Train Discriminator

         # REAL
         real_images = Variable(images_to_vectors(batch)).float()

         # FAKE
         fake_images = generator(generator.noise(N)).detach()

         # TRAIN
         d_error, d_pred_real, d_pred_fake = train_discriminator(
                 discriminator,
                 d_optimizer,
                 real_images,
                 fake_images
         )

         # Train Generator

         # generate noise
         fake_data = generator.noise(N)

         # get error based on discriminator
         g_error = train_generator(generator, g_optimizer, fake_data)

         # convert generator output to image and preprocess to show
         test_img = np.array(images_to_vectors(generator(fake_data), reverse=True).detach())
         test_img = test_img[0, :, :, :]
         test_img = test_img[..., ::-1]

         # show example of generated image
         cv2.imshow('GENERATED', test_img[0])
         if cv2.waitKey(1) & 0xFF == ord('q'):
             break

     print('EPOCH: {0}, D error: {1}, G error: {2}'.format(epoch, d_error, g_error))


cv2.destroyAllWindows()

# save weights
# torch.save('weights.pth')
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-06-08 14:37:52

如果没有数据等等,就不可能轻松地调试您的培训,但是一个可能的问题是,您的生成器的最后一层是Tanh(),这意味着-11之间的输出值。你可能想:

  1. 将您的真实图像归一化到相同的范围,例如在train_discriminator()中:训练在真实的pred_real =甄别器( real_data *2.-1.)#假定real_data在0,1
  2. 在可视化/ and之前将生成的数据重新规范化为[0, 1]。将生成器输出转换为图像并进行预处理,以显示test_img = np.array( images_to_vectors(生成器(Fake_data),reverse=True).detach() test_img = test_img0,:,,:test_img = test_img .,::-1 test_img= (test_img + 1.) / 2。
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50762466

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档