首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从Tensorflow模型恢复GAN的生成器?

如何从Tensorflow模型恢复GAN的生成器?
EN

Stack Overflow用户
提问于 2019-06-19 22:06:12
回答 1查看 1.2K关注 0票数 0

我试图用Tensorflow模型(元计时器和检查点)恢复生成对抗网络的经过训练的生成器。

我刚接触过tensorflow和python,所以我不确定我所做的是否有意义。已经尝试过从元文件中导入元计时器并从检查点恢复变量,但是我确定下一步该做什么。我的目标是从上一个检查点恢复经过训练的生成器,然后使用它从噪声输入生成新的数据。

下面是一个包含模型文件的驱动器的链接:Fq?usp=sharing

到目前为止,我已经尝试了以下方法,它似乎正在加载图形:

代码语言:javascript
复制
# import the graph from the file
imported_graph = tf.train.import_meta_graph("../../models/model-9.meta")

# list all the tensors in the graph
for tensor in tf.get_default_graph().get_operations():
    print (tensor.name)

# run the session
with tf.Session() as sess:
    # restore the saved vairable
    imported_graph.restore(sess, "../../models/model-9")

然而,我不知道下一步该做什么。是否可以使用这些文件只运行经过训练的生成器?我怎么才能拿到呢?

EN

回答 1

Stack Overflow用户

发布于 2019-11-09 09:45:45

在Tensorflow 2文档中,它们保存--生成器和鉴别器--。但是,它们并没有解释如何只还原生成器。

代码语言:javascript
复制
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

然后用

代码语言:javascript
复制
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

来自checkpoints

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56676497

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档