首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何用TensorFlow加载训练网络的权值

如何用TensorFlow加载训练网络的权值
EN

Stack Overflow用户
提问于 2018-06-26 13:10:34
回答 1查看 140关注 0票数 1

当尝试在经过训练的网络的多个历元上加载保存的权重时,请使用以下代码返回:

代码语言:javascript
复制
import tensorflow as tf
from returnn.Config import Config
from returnn.TFNetwork import TFNetwork

for i in range(1,11):
    modelFilePath = path/to/model/ + 'network.' + '%03d' % (i,)

    returnnConfig = Config()
    returnnConfig.load_file(path/to/configFile)
    returnnTfNetwork = TFNetwork(config=path/to/configFile, train_flag=False, eval_flag=True)

    returnnTfNetwork.construct_from_dict(returnnConfig.typed_value('network'))

    with tf.Session() as sess:
        returnnTfNetwork.load_params_from_file(modelFilePath, sess)

我得到以下错误:

代码语言:javascript
复制
Variables to restore which are not in checkpoint:
global_step_1

Variables in checkpoint which are not needed for restore:
global_step

Probably we can restore these:
(None)

Error, some entry is missing in the checkpoint
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-06-26 13:14:49

问题是每次在循环中都要重新创建TFNetwork,而且每次都会为全局步骤创建一个新变量,因为每个变量都必须有唯一的名称,因此必须调用不同的变量。

您可以在循环中这样做:

代码语言:javascript
复制
tf.reset_default_graph()
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51043758

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档