I have been trying to learn some TensorFlow. Source: gans. The good news is that I managed to run import_example.py and got correct output image samples from the trained model karras2018iclr-celebahq-1024x1024.pkl. Next, I used dataset_tool.py (create_from_images) to create my own dataset from 400 images at 1024x1024. I modified config.py to:
data_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets/MYIMAGES/'
result_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/results/'
and created a new dataset.
However, I get:
(ProgressiveGAN) C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018>python train.py
Initializing TensorFlow...
Running train.train_progressive_gan()...
Streaming data using dataset.TFRecordDataset...
Traceback (most recent call last):
  File "train.py", line 285, in <module>
    tfutil.call_func_by_name(**config.train)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\tfutil.py", line 236, in call_func_by_name
    return import_obj(func)(*args, **kwargs)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\train.py", line 151, in train_progressive_gan
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\dataset.py", line 234, in load_dataset
    dataset = tfutil.import_obj(class_name)(**adjusted_kwargs)
  File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\dataset.py", line 67, in __init__
    assert os.path.isdir(self.tfrecord_dir)
AssertionError

config.py code:
# Paths.
data_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets/MYIMAGES'
result_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/results/'
# Official training configs, targeted mainly for CelebA-HQ.
# To run, comment/uncomment the lines as appropriate and launch train.py.
desc = 'pgan' # Description string included in result subdir name.
random_seed = 1000 # Global random seed.
dataset = EasyDict() # Options for dataset.load_dataset().
train = EasyDict(func='train.train_progressive_gan') # Options for main training func.
G = EasyDict(func='networks.G_paper') # Options for generator network.
D = EasyDict(func='networks.D_paper') # Options for discriminator network.
G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
G_loss = EasyDict(func='loss.G_wgan_acgan') # Options for generator loss.
D_loss = EasyDict(func='loss.D_wgangp_acgan') # Options for discriminator loss.
sched = EasyDict() # Options for train.TrainingSchedule.
grid = EasyDict(size='1080p', layout='random') # Options for train.setup_snapshot_image_grid().
# Dataset (choose one).
desc += '-MYIMAGES'; dataset = EasyDict(tfrecord_dir='MYIMAGES'); train.mirror_augment = True

I want to train on my own 400 images at 1024x1024.
Posted on 2019-11-18 20:46:09
I ran into this problem recently. The fix is very simple.
All you have to do is put your tfrecord_dir inside the datasets directory, i.e. datasets/tfrecord_dir.
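To see why the assertion fires, here is a minimal sketch. It assumes, as the call to load_dataset(data_dir=config.data_dir, **config.dataset) in the traceback suggests, that data_dir and tfrecord_dir are joined into a single path before the os.path.isdir check. The helper name resolve_tfrecord_dir is hypothetical and is not part of the repository.

```python
import os

def resolve_tfrecord_dir(data_dir, tfrecord_dir):
    """Hypothetical helper: combine the two config values the way the
    loader is assumed to (data_dir joined with tfrecord_dir)."""
    return os.path.join(data_dir, tfrecord_dir)

# Parent folder that holds all datasets (path taken from the question).
datasets_root = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets'

# Config from the question: data_dir already ends in MYIMAGES, so the
# dataset name appears twice in the combined path.
broken = resolve_tfrecord_dir(datasets_root + '/MYIMAGES', 'MYIMAGES')

# Suggested layout: data_dir is the parent, tfrecord_dir is the subfolder.
fixed = resolve_tfrecord_dir(datasets_root, 'MYIMAGES')

# Report whether each combined path exists on this machine.
for label, path in (('question', broken), ('suggested', fixed)):
    status = 'exists' if os.path.isdir(path) else 'missing'
    print(label, path, status)
```

Under that assumption, the configuration in the question resolves to a MYIMAGES folder nested inside MYIMAGES, which would not exist. Pointing data_dir at the datasets folder and keeping tfrecord_dir='MYIMAGES' should make the combined path match the folder that dataset_tool.py wrote.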
https://stackoverflow.com/questions/57210881