我完全是Python的初学者。最近,我将MNIST手写数字数据集加载到了一个生成对抗网络(GAN)中。程序运行良好,但我想知道如何修改下面的代码,以便加载我自己的训练数据(即一个存放JPG图片的文件夹),而不是MNIST数据集。有没有一种简单的方法可以用这段代码做到这一点?
我知道我需要将图像转换为MNIST格式,但除此之外,我不知道需要添加和/或修改哪些代码行才能加载该文件夹。
谢谢你的帮助!
import os

# The backend must be selected before the first `import keras`; assigning
# KERAS_BACKEND after Keras has been imported does not change the backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import matplotlib
# The matplotlib backend is chosen before pyplot is imported.
matplotlib.use("TkAgg")
from matplotlib import pyplot as plt
from tqdm import tqdm
from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers

# Seed NumPy's global RNG so the sampled noise and batches are repeatable.
np.random.seed(10)

# Length of the random noise vector fed to the generator.
random_dim = 100
def load_minst_data():
    """Load MNIST and return it with the training images ready for the GAN.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``. ``x_train`` is
        converted to float32, rescaled with ``(x - 127.5) / 127.5`` and
        flattened to one row per image. The other three arrays are returned
        exactly as ``mnist.load_data()`` provides them.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Maps pixel value 0 to -1 and 255 to 1, the same range as the
    # generator's tanh output layer.
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
    # Derive the sample count and the flattened size from the array itself
    # instead of hard-coding (60000, 784), so swapping in a dataset with a
    # different size or image shape does not break the reshape.
    x_train = x_train.reshape(x_train.shape[0], -1)
    return (x_train, y_train, x_test, y_test)
def get_optimizer():
    """Return a new Adam optimizer with lr=0.0002 and beta_1=0.5."""
    learning_rate = 0.0002
    first_moment_decay = 0.5
    return Adam(lr=learning_rate, beta_1=first_moment_decay)
def get_generator(optimizer):
    """Build and compile the generator network.

    The model is a stack of Dense layers (256 -> 512 -> 1024), each followed
    by LeakyReLU(0.2), ending in a 784-unit tanh layer. Its input size is the
    module-level ``random_dim``.

    Args:
        optimizer: Optimizer instance passed to ``compile``.

    Returns:
        The compiled ``Sequential`` model.
    """
    weight_init = initializers.RandomNormal(stddev=0.02)

    net = Sequential()
    # Only the first layer declares the input size and a custom initializer.
    net.add(Dense(256, input_dim=random_dim, kernel_initializer=weight_init))
    net.add(LeakyReLU(0.2))
    for units in (512, 1024):
        net.add(Dense(units))
        net.add(LeakyReLU(0.2))
    net.add(Dense(784, activation='tanh'))

    net.compile(optimizer=optimizer, loss='binary_crossentropy')
    return net
def get_discriminator(optimizer):
    """Build and compile the discriminator network.

    The model takes a 784-element input and passes it through Dense layers
    (1024 -> 512 -> 256), each followed by LeakyReLU(0.2) and Dropout(0.3),
    ending in a single sigmoid unit.

    Args:
        optimizer: Optimizer instance passed to ``compile``.

    Returns:
        The compiled ``Sequential`` model.
    """
    weight_init = initializers.RandomNormal(stddev=0.02)

    net = Sequential()
    # Only the first layer declares the input size and a custom initializer.
    net.add(Dense(1024, input_dim=784, kernel_initializer=weight_init))
    net.add(LeakyReLU(0.2))
    net.add(Dropout(0.3))
    for units in (512, 256):
        net.add(Dense(units))
        net.add(LeakyReLU(0.2))
        net.add(Dropout(0.3))
    net.add(Dense(1, activation='sigmoid'))

    net.compile(optimizer=optimizer, loss='binary_crossentropy')
    return net
def get_gan_network(discriminator, random_dim, generator, optimizer):
    """Chain generator and discriminator into one compiled model.

    Args:
        discriminator: Discriminator model; its ``trainable`` flag is set to
            False here before the combined model is built.
        random_dim: Length of the noise vector used as the model input.
        generator: Generator model applied to the noise input.
        optimizer: Optimizer instance passed to ``compile``.

    Returns:
        A compiled ``Model`` mapping a noise vector to the discriminator's
        output for the generated sample.
    """
    discriminator.trainable = False

    latent = Input(shape=(random_dim,))
    verdict = discriminator(generator(latent))

    combined = Model(inputs=latent, outputs=verdict)
    combined.compile(optimizer=optimizer, loss='binary_crossentropy')
    return combined
def plot_generated_images(epoch, generator, examples=100, dim=(10, 10),
                          figsize=(10, 10)):
    """Sample images from the generator and save them as a PNG grid.

    Args:
        epoch: Epoch number, used only in the output file name.
        generator: Model whose ``predict`` maps noise of shape
            ``(examples, random_dim)`` to outputs that can be reshaped to
            ``(examples, 28, 28)``.
        examples: Number of images to generate.
        dim: ``(rows, cols)`` of the subplot grid.
        figsize: Figure size passed to ``plt.figure``.

    Side effects:
        Writes ``gan_generated_image_epoch_<epoch>.png`` to the current
        working directory.
    """
    noise = np.random.normal(0, 1, size=[examples, random_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)

    fig = plt.figure(figsize=figsize)
    # Subplot indices are 1-based.
    for position, image in enumerate(generated_images, start=1):
        plt.subplot(dim[0], dim[1], position)
        plt.imshow(image, interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image_epoch_%d.png' % epoch)
    # Close the figure after saving. Otherwise each call leaves one more open
    # figure behind and they accumulate over a long training run.
    plt.close(fig)
def train(epochs=1, batch_size=128):
    """Train the GAN and periodically save sample images.

    Args:
        epochs: Number of passes; each pass runs
            ``x_train.shape[0] // batch_size`` training steps.
        batch_size: Number of real and of generated samples per step.

    Sample grids are saved after epoch 1 and after every 20th epoch.
    """
    # Only the training images are used; the other arrays are ignored.
    x_train, _y_train, _x_test, _y_test = load_minst_data()
    sample_count = x_train.shape[0]
    batch_count = sample_count // batch_size

    # One optimizer instance is handed to all three models.
    shared_optimizer = get_optimizer()
    generator = get_generator(shared_optimizer)
    discriminator = get_discriminator(shared_optimizer)
    gan = get_gan_network(discriminator, random_dim, generator,
                          shared_optimizer)

    for epoch in range(1, epochs + 1):
        print('-'*15, 'Epoch %d' % epoch, '-'*15)
        for _ in tqdm(range(batch_count)):
            # --- Discriminator step ---
            latent = np.random.normal(0, 1, size=[batch_size, random_dim])
            # Real images are drawn at random, with replacement.
            picks = np.random.randint(0, sample_count, size=batch_size)
            real_images = x_train[picks]
            fake_images = generator.predict(latent)
            X = np.concatenate([real_images, fake_images])

            # Real samples come first and are labelled 0.9; generated
            # samples are labelled 0.
            y_dis = np.concatenate([np.full(batch_size, 0.9),
                                    np.zeros(batch_size)])

            discriminator.trainable = True
            discriminator.train_on_batch(X, y_dis)

            # --- Generator step (through the combined model) ---
            latent = np.random.normal(0, 1, size=[batch_size, random_dim])
            discriminator.trainable = False
            gan.train_on_batch(latent, np.ones(batch_size))

        if epoch == 1 or epoch % 20 == 0:
            plot_generated_images(epoch, generator)
if __name__ == '__main__':
    train(400, 128)

发布于 2019-03-31 02:57:42
MNIST是一个数据集而不是一种格式。
以下代码行是代码作者加载数据集的位置:
x_train, y_train, x_test, y_test = load_minst_data()
这行代码位于:
def train(epochs=1, batch_size=128):
x_train, y_train, x_test, y_test = load_minst_data()
    ...
这将调用以下函数:
def load_minst_data():
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (x_train.astype(np.float32) - 127.5)/127.5
x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)
你可以修改这个函数来加载你自己的数据集。
此外,由于代码期望得到x_train、y_train、x_test、y_test这四个变量,如下所示:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
我建议使用scikit-learn的train_test_split。该函数可以帮助您将数据拆分为上述变量。
下面这行代码:
x_train = (x_train.astype(np.float32) - 127.5)/127.5
对训练样本进行归一化(将像素值从[0, 255]缩放到[-1, 1])。
而在下面这一行:
x_train = x_train.reshape(60000, 784)
作者将60000个样本中的每一个都展平为大小为784的向量,以便将它们输入模型。
因为MNIST图像原本是28x28像素(28 × 28 = 784),所以可以重塑为784。
您还需要修改数据形状,或者更改下面这一行中的input_dim:
discriminator.add(Dense(1024, input_dim=784,kernel_initializer=initializers.RandomNormal(stddev=0.02)))
可能还需要修改代码中其他每一处出现784的地方。
https://stackoverflow.com/questions/55434694
复制相似问题