我正在尝试加载检查点,并使用TF2.1保存它们的平均权重。我找到了它的TF1版本。https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py
变量"checkpoints“是检查点路径的列表
# Read variables from all checkpoints and average them.
logger.info("Reading variables and averaging checkpoints:")
for c in checkpoints:
logger.info(c)
var_list = tf.train.list_variables(checkpoints[0])
var_values, var_dtypes = {}, {}
for (name, shape) in var_list:
if not name.startswith("global_step"):
var_values[name] = tf.zeros(shape)
for checkpoint in checkpoints:
reader = tf.train.load_checkpoint(checkpoint)
for name in var_values:
tensor = tf.convert_to_tensor(reader.get_tensor(name))
if tensor.dtype == tf.string:
var_values[name] = tensor
else:
var_values[name] = tf.cast(var_values[name], tensor.dtype)
var_values[name] += tensor
var_dtypes[name] = tensor.dtype
logger.info("Read from checkpoint %s", checkpoint)
for name in var_values: # Average.
if var_dtypes[name] != tf.string:
var_values[name] /= len(checkpoints)你能解释一下如何将平均var_values保存到检查点吗?
发布于 2020-03-27 20:06:25
我可以通过参考同一问题的Keras版本来保存平均检查点,因为Tensorflow 2.1遵循Keras API。
https://stackoverflow.com/questions/60817342
复制相似问题