首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用Nvidia Progressive_GAN修复‘Progressive_GAN’?

如何使用Nvidia Progressive_GAN修复‘Progressive_GAN’?
EN

Stack Overflow用户
提问于 2019-07-25 22:27:40
回答 1查看 712关注 0票数 0

我一直在试着了解一些关于丹索尔·弗洛的知识。资料来源:gans.好消息是,我成功地完成了import_example.py,使用了经过训练的模型中的正确输出图像样本:Karras2018iclr-卤代克-1024x1024.pkl。现在,我使用dataset_tool.py (create_from_images)创建了自己的数据集( dataset,create_from_images),其中400个图像位于1024x1024。我修改了config.py

致:'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets/MYIMAGES/‘result_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/results/’

并创建了一个新的数据集。

然而,我得到了:

代码语言:javascript
复制
(ProgressiveGAN) C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018>python train.py
Initializing TensorFlow...
Running train.train_progressive_gan()...
Streaming data using dataset.TFRecordDataset...
Traceback (most recent call last):
File "train.py", line 285, in <module>
tfutil.call_func_by_name(**config.train)
File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\tfutil.py", line 236, in call_func_by_name
return import_obj(func)(*args, **kwargs)
File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\train.py", line 151, in train_progressive_gan
training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)
File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\dataset.py", line 234, in load_dataset
dataset = tfutil.import_obj(class_name)(**adjusted_kwargs)
File "C:\Users\Anaconda3\envs\ProgressiveGAN\source\code\2018\dataset.py", line 67, in __init__
assert os.path.isdir(self.tfrecord_dir)
AssertionError

config.py代码

代码语言:javascript
复制
# Paths.

data_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/datasets/MYIMAGES'
result_dir = 'C:/Users/Anaconda3/envs/ProgressiveGAN/source/results/'

# Official training configs, targeted mainly for CelebA-HQ.
# To run, comment/uncomment the lines as appropriate and launch train.py.

desc        = 'pgan'                                        # Description     string included in result subdir name.
random_seed = 1000                                          # Global random seed.
dataset     = EasyDict()                                    # Options for dataset.load_dataset().
train       = EasyDict(func='train.train_progressive_gan')  # Options for main training func.
G           = EasyDict(func='networks.G_paper')             # Options for generator network.
D           = EasyDict(func='networks.D_paper')             # Options for discriminator network.
G_opt       = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
D_opt       = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
G_loss      = EasyDict(func='loss.G_wgan_acgan')            # Options for generator loss.
D_loss      = EasyDict(func='loss.D_wgangp_acgan')          # Options for discriminator loss.
sched       = EasyDict()                                    # Options for train.TrainingSchedule.
grid        = EasyDict(size='1080p', layout='random')       # Options for train.setup_snapshot_image_grid().

# Dataset (choose one).
desc += '-MYIMAGES';            dataset = EasyDict(tfrecord_dir='MYIMAGES'); train.mirror_augment = True

我想训练我自己的400张1024x1024的图片。

EN

回答 1

Stack Overflow用户

发布于 2019-11-18 20:46:09

我最近遇到了这个问题。这是一个非常简单的解决办法。

您所要做的就是将您的tfrecord_dir放入目录数据集,即tfrecord/dir集。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57210881

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档