首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在tensorflow中加载较旧的检查点

在tensorflow中加载较旧的检查点
EN

Stack Overflow用户
提问于 2017-04-10 18:38:16
回答 1查看 711关注 0票数 0

我使用Tensorflow r0.12训练了一些模型并保存了它。后来我更新到了r1.0.1。一些模型加载没有任何问题,但是如果模型中有RNN单元,加载失败并显示Key layer-5/bidirectional_rnn/bw/multi_rnn_cell/cell_1/basic_rnn_cell/biases not found in checkpoint。同样,如果我检查model.index文件,我会看到类似的条目,例如:5/BiRNN/BW/MultiRNNCell/Cell0/BasicRNNCell/Linear/Bias

带有RNN cell的包现在是tf.contrib.rnn (在0.12中是tf.nn.rnn_cell ),所以我认为一些命名已经改变了。

问题是:有没有一种方法可以加载我的模型,重新映射它的张量并保存,以便张量名称与r1.0兼容?

附言:我也有model.meta文件,如果有帮助的话。

谢谢!

EN

回答 1

Stack Overflow用户

发布于 2017-04-10 21:31:02

如果有人遇到同样的问题,下面是我使用的解决方案。它是tensorflow.python.toolsinspect_checkpoint.py的张量打印函数的修改版本。

代码语言:javascript
复制
def resave_tensors(file_name, rename_map, dry_run=False):
    """
    Updates checkpoint by renaming tensors in it.
    :param file_name: Filename with checkpoint.
    :param rename_map: Map from old names to new ones
    :param dry_run: If True, just print new tensors.
    """
    renames_count = 0
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in sorted(var_to_shape_map):
        print("tensor_name: ", key)
        tensor_val = reader.get_tensor(key)
        print('shape: {}'.format(tensor_val.shape))
        if key in rename_map:
            renames_count += 1
            key = rename_map[key]
        tf.Variable(tensor_val, dtype=tensor_val.dtype, name=key)
    saver = tf.train.Saver()
    if not dry_run:
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            saver.save(session, file_name)
    print('Renamed vars: {}'.format(renames_count))
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/43320993

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档