
How should I load my own training data into this generative adversarial network?

Stack Overflow user
Asked on 2019-03-31 02:52:35
1 answer · 475 views · 0 followers · 0 votes

I am an absolute beginner in Python. I recently loaded the MNIST handwritten-digit database into a generative adversarial network. The program runs fine, but I would like to know how to modify the code below so that it loads my own training data, i.e. a folder of JPG files, instead of the MNIST database. Is there an easy way to do that with this code?

I know I need to convert my images into the format MNIST uses, but beyond that I don't know which lines I have to add and/or edit to load the folder.

Thanks for your help!

Code (Python):

import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before Keras is imported

import numpy as np
import matplotlib
matplotlib.use("TkAgg")
from matplotlib import pyplot as plt
from tqdm import tqdm
from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers

np.random.seed(10)
random_dim = 100  # dimensionality of the generator's noise input

def load_minst_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
    # Flatten each 28x28 image into a 784-dimensional vector
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)

def get_optimizer():
    return Adam(lr=0.0002, beta_1=0.5)

def get_generator(optimizer):
    generator = Sequential()
    generator.add(Dense(256, input_dim=random_dim,
                        kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(1024))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(784, activation='tanh'))
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator

def get_discriminator(optimizer):
    discriminator = Sequential()
    discriminator.add(Dense(1024, input_dim=784,
                            kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    discriminator.add(Dense(1, activation='sigmoid'))
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator

def get_gan_network(discriminator, random_dim, generator, optimizer):
    # Freeze the discriminator while the combined model trains the generator
    discriminator.trainable = False
    gan_input = Input(shape=(random_dim,))
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss='binary_crossentropy', optimizer=optimizer)
    return gan

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10),
                          figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, random_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image_epoch_%d.png' % epoch)

def train(epochs=1, batch_size=128):
    x_train, y_train, x_test, y_test = load_minst_data()
    batch_count = x_train.shape[0] // batch_size
    adam = get_optimizer()
    generator = get_generator(adam)
    discriminator = get_discriminator(adam)
    gan = get_gan_network(discriminator, random_dim, generator, adam)
    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batch_count)):
            # Train the discriminator on a mixed batch of real and generated images
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
            generated_images = generator.predict(noise)
            X = np.concatenate([image_batch, generated_images])
            y_dis = np.zeros(2 * batch_size)
            y_dis[:batch_size] = 0.9  # one-sided label smoothing for real images
            discriminator.trainable = True
            discriminator.train_on_batch(X, y_dis)
            # Train the generator through the combined model (discriminator frozen)
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y_gen)
        if e == 1 or e % 20 == 0:
            plot_generated_images(e, generator)

if __name__ == '__main__':
    train(400, 128)

1 Answer

Stack Overflow user

Answered on 2019-03-31 02:57:42

MNIST is a dataset, not a format.

The following line is where the author of the code loads the dataset:

x_train, y_train, x_test, y_test = load_minst_data()

in:

def train(epochs=1, batch_size=128):
    x_train, y_train, x_test, y_test = load_minst_data()
    ...

This calls the following function:

def load_minst_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = (x_train.astype(np.float32) - 127.5)/127.5
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)

You can modify this function to load your own dataset.
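A minimal sketch of such a replacement loader, assuming your JPGs live in a folder named `my_images` (a hypothetical path) and that you resize everything to 28x28 grayscale so the rest of the network can stay unchanged. It uses Pillow to read the images and scikit-learn's train_test_split to produce the four variables the training code expects:

```python
import glob
import os

import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

def load_jpg_data(folder="my_images", size=(28, 28)):
    """Load every JPG in `folder` as flattened, grayscale, [-1, 1]-scaled vectors."""
    images = []
    for path in sorted(glob.glob(os.path.join(folder, "*.jpg"))):
        img = Image.open(path).convert("L").resize(size)  # grayscale, 28x28
        images.append(np.asarray(img, dtype=np.float32))
    x = np.stack(images)
    x = (x - 127.5) / 127.5                    # [0, 255] -> [-1, 1], like the MNIST loader
    x = x.reshape(len(x), size[0] * size[1])   # flatten each image to a 784-vector
    y = np.zeros(len(x))                       # dummy labels; GAN training never uses them
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
    return (x_train, y_train, x_test, y_test)
```

The return order matches `load_minst_data()`, so `train()` can call it as a drop-in replacement.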

Also, since the code expects the four variables x_train, y_train, x_test, y_test, which come from:

(x_train, y_train), (x_test, y_test) = mnist.load_data()

I suggest using scikit-learn's train_test_split. That function will help you split your data into the variables above.
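For example, a quick sketch of how train_test_split produces those four variables, with random arrays standing in for your flattened images and labels:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Dummy data standing in for 100 flattened 28x28 images and their labels
x = np.random.rand(100, 784).astype(np.float32)
y = np.zeros(100)

# 80% train / 20% test, mirroring the four variables the code expects
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=10)
```

Note that train_test_split returns the splits in the order (x_train, x_test, y_train, y_test), which differs from the (x_train, y_train), (x_test, y_test) pairs that mnist.load_data() returns.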

The following line:

x_train = (x_train.astype(np.float32) - 127.5)/127.5

normalizes the training samples (it scales pixel values from [0, 255] to [-1, 1]).
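Concretely, that scaling maps 0 to -1.0 and 255 to 1.0, matching the [-1, 1] range of the generator's tanh output layer:

```python
import numpy as np

pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)
scaled = (pixels - 127.5) / 127.5
print(scaled)  # [-1.  0.  1.]
```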

And in:

x_train = x_train.reshape(60000, 784)

the author flattens the 60000 samples into vectors of size 784 so they can be fed to the model.

Since MNIST images are originally 28x28 pixels, reshaping them to vectors of length 784 (28 * 28) is possible.

You will also need to adapt your data's shape, or change input_dim, in:

discriminator.add(Dense(1024, input_dim=784,kernel_initializer=initializers.RandomNormal(stddev=0.02)))

and probably everywhere else 784 appears.
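As a sketch, if your JPGs were resized to, say, 64x64 (a hypothetical size), every hard-coded 784 would become 64 * 64 = 4096. Computing the flattened size from the image dimensions in one place makes this less error-prone:

```python
# Hypothetical target size for your own JPGs
width, height = 64, 64
flat_dim = width * height  # 4096; replaces every hard-coded 784

# Places in the question's code that would use flat_dim instead of 784
# (shown as comments, not run here):
#   x_train = x_train.reshape(len(x_train), flat_dim)
#   discriminator.add(Dense(1024, input_dim=flat_dim, ...))
#   generator.add(Dense(flat_dim, activation='tanh'))
#   generated_images.reshape(examples, height, width)
print(flat_dim)
```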

Votes: 0
The original content of this page was provided by Stack Overflow; translation support by Tencent Cloud's IT-domain translation engine.
Original link:

https://stackoverflow.com/questions/55434694
