首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用Tensorflow 2.1保存检查点的平均权重?

如何使用Tensorflow 2.1保存检查点的平均权重?
EN

Stack Overflow用户
提问于 2020-03-24 00:11:23
回答 1查看 341关注 0票数 0

我正在尝试加载检查点,并使用TF2.1保存它们的平均权重。我找到了它的TF1版本。https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py

变量"checkpoints“是检查点路径的列表

代码语言:javascript
复制
  # Read variables from all checkpoints and average them.
  logger.info("Reading variables and averaging checkpoints:")
  for c in checkpoints:
    logger.info(c)

  var_list = tf.train.list_variables(checkpoints[0])

  var_values, var_dtypes = {}, {}
  for (name, shape) in var_list:
    if not name.startswith("global_step"):
      var_values[name] = tf.zeros(shape)

  for checkpoint in checkpoints:
    reader = tf.train.load_checkpoint(checkpoint)
    for name in var_values:
      tensor = tf.convert_to_tensor(reader.get_tensor(name))

      if tensor.dtype == tf.string:
        var_values[name] = tensor
      else:
        var_values[name] = tf.cast(var_values[name], tensor.dtype)
        var_values[name] += tensor
      var_dtypes[name] = tensor.dtype
    logger.info("Read from checkpoint %s", checkpoint)

  for name in var_values:  # Average.
    if var_dtypes[name] != tf.string:
      var_values[name] /= len(checkpoints)

你能解释一下如何将平均var_values保存到检查点吗?

EN

回答 1

Stack Overflow用户

发布于 2020-03-27 20:06:25

我可以通过参考同一问题的Keras版本来保存平均检查点,因为Tensorflow 2.1遵循Keras API。

网址:Average weights in keras models

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60817342

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档