首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >无法在tensorflow中保存/恢复变量子集

无法在tensorflow中保存/恢复变量子集
EN

Stack Overflow用户
提问于 2017-08-26 10:48:10
回答 1查看 209关注 0票数 0

我正在用ImageNet训练我的网络,这样我就可以在我的项目中使用训练过的权重的子集。

保存和恢复洞重不是一个问题,但是当我试图在没有完全连接层的情况下保存它们时,它会给我一个错误:NameError:全局名称'w1‘不是定义的。如果它对任何人都有帮助,那么存储库就在github或代码片段中:

inference.py

代码语言:javascript
复制
...
def inference(images):
    w1 = tf.get_variable('w1', shape=[5,5,3,64])   
    ...

grasp.py

代码语言:javascript
复制
def run_training():
  ...
  logits = inference(images)
  ...
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess = tf.Session()
sess.run(init_op)
saver = tf.train.Saver({'w1': w1})

回溯

代码语言:javascript
复制
Traceback (most recent call last):
  File "./grasp.py", line 130, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/usr/local/lib/python2.7/site-
    packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "./grasp.py", line 83, in main
    run_training()
  File "./grasp.py", line 52, in run_training
    saver = tf.train.Saver({'w1': w1})
NameError: global name 'w1' is not defined

如果您有任何建议或需要更多的信息,请告诉我。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-08-26 11:27:31

在这里,您需要访问tf.global_variables()集中的变量。

代码语言:javascript
复制
w1 = [v for v in tf.global_variables() if v.name == 'w1:0'][0]
saver = tf.train.Saver({'w1': w1})
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45894537

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档