首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TensorFlow:恢复多个图

TensorFlow:恢复多个图
EN

Stack Overflow用户
提问于 2016-10-18 02:50:29
回答 1查看 1.4K关注 0票数 3

假设我们有两个TensorFlow计算图,G1G2,其保存的权重为W1W2。假设我们简单地通过构造GG2来构建一个新的图G2。如何恢复这个新图的W1W2G

举一个简单的例子:

代码语言:javascript
复制
import tensorflow as tf

V1 = tf.Variable(tf.zeros([1]))
saver_1 = tf.train.Saver()
V2 = tf.Variable(tf.zeros([1]))
saver_2 = tf.train.Saver()

sess = tf.Session()
saver_1.restore(sess, 'W1')
saver_2.restore(sess, 'W2')

在本例中,saver_1成功地还原了相应的V1,但是saver_2NotFoundError中失败了。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-10-27 05:07:50

您可能可以使用两个保护程序,每个保护程序只查找其中一个变量。如果您只使用tf.train.Saver(),我认为它将查找您定义的所有变量。您可以使用tf.train.Saver([v1, ...])为它提供要查找的变量列表。有关更多信息,您可以在这里阅读tf.train.Saver构造函数:ops.html#Saver

下面是一个简单的工作示例。假设您在一个文件"save_vars.py“中进行计算,并且它有以下代码:

代码语言:javascript
复制
import tensorflow as tf

# Graph 1 - set v1 to have value [1.0]
g1 = tf.Graph()
with g1.as_default():
    v1 = tf.Variable(tf.zeros([1]), name="v1")
    assign1 = v1.assign(tf.constant([1.0]))
    init1 = tf.initialize_all_variables()
    save1 = tf.train.Saver()

# Graph 2 - set v2 to have value [2.0]
g2 = tf.Graph()
with g2.as_default():
    v2 = tf.Variable(tf.zeros([1]), name="v2")
    assign2 = v2.assign(tf.constant([2.0]))
    init2 = tf.initialize_all_variables()
    save2 = tf.train.Saver()

# Do the computation for graph 1 and save
sess1 = tf.Session(graph=g1)
sess1.run(init1)
print sess1.run(assign1)
save1.save(sess1, "tmp/v1.ckpt")

# Do the computation for graph 2 and save
sess2 = tf.Session(graph=g2)
sess2.run(init2)
print sess2.run(assign2)
save2.save(sess2, "tmp/v2.ckpt")

如果您确保您有一个tmp目录并运行python save_vars.py,您将得到保存的检查点文件。

现在,您可以用以下代码使用名为"restore_vars.py“的文件进行还原:

代码语言:javascript
复制
import tensorflow as tf

# The variables v1 and v2 that we want to restore
v1 = tf.Variable(tf.zeros([1]), name="v1")
v2 = tf.Variable(tf.zeros([1]), name="v2")

# saver1 will only look for v1
saver1 = tf.train.Saver([v1])
# saver2 will only look for v2
saver2 = tf.train.Saver([v2])
with tf.Session() as sess:
    saver1.restore(sess, "tmp/v1.ckpt")
    saver2.restore(sess, "tmp/v2.ckpt")
    print sess.run(v1)
    print sess.run(v2)

当您运行python restore_vars.py时,输出应该是

代码语言:javascript
复制
[1.]
[2.]

(至少在我的电脑上,这是输出)。如果有什么不清楚的地方,可以随时发表评论。

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/40098743

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档