首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow -恢复模型的平均模型权重

Tensorflow -恢复模型的平均模型权重
EN

Stack Overflow用户
提问于 2018-05-16 14:33:05
回答 2查看 2.4K关注 0票数 2

考虑到我在同一数据上训练了几个不同的模型,而我训练的所有神经网络都有相同的体系结构,我想知道是否有可能恢复这些模型,平均它们的权重,并用平均值初始化我的权重。

这是一个图形外观的例子。基本上,我需要的是一个平均的重量,我要负荷。

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

#init model1 weights
weights = {
    'w1': tf.Variable(),
    'w2': tf.Variable()
}
# init model1 biases
biases = {
    'b1': tf.Variable(),
    'b2': tf.Variable()
}
#init model2 weights
weights2 = {
    'w1': tf.Variable(),
    'w2': tf.Variable()
}
# init model2 biases
biases2 = {
    'b1': tf.Variable(),
    'b2': tf.Variable(),
}

# this the average I want to create
w = {
    'w1': tf.Variable(
        tf.add(weights["w1"], weights2["w1"])/2
    ),
    'w2': tf.Variable(
        tf.add(weights["w2"], weights2["w2"])/2
    ),
    'w3': tf.Variable(
        tf.add(weights["w3"], weights2["w3"])/2
    )
}
# init biases
b = {
    'b1': tf.Variable(
        tf.add(biases["b1"], biases2["b1"])/2
    ),
    'b2': tf.Variable(
        tf.add(biases["b2"], biases2["b2"])/2
    ),
    'b3': tf.Variable(
        tf.add(biases["b3"], biases2["b3"])/2
    )
}

weights_saver = tf.train.Saver({
    'w1' : weights['w1'],
    'w2' : weights['w2'],
    'b1' : biases['b1'],
    'b2' : biases['b2']
    })
weights_saver2 = tf.train.Saver({
    'w1' : weights2['w1'],
    'w2' : weights2['w2'],
    'b1' : biases2['b1'],
    'b2' : biases2['b2']
    })

这就是我在tf会议上想要得到的。C包含我想用来开始训练的权重。

代码语言:javascript
复制
# Create a session for running operations in the Graph.
init_op = tf.global_variables_initializer()
init_op2 = tf.local_variables_initializer()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # Initialize the variables (like the epoch counter).
    sess.run(init_op)
    sess.run(init_op2)
    weights_saver.restore(
        sess,
        'my_model1/model_weights.ckpt'
    )
    weights_saver2.restore(
        sess,
        'my_model2/model_weights.ckpt'
    )
    a = sess.run(weights)
    b = sess.run(weights2)
    c = sess.run(w)
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-05-16 14:45:01

首先,假设模型结构完全相同(相同的层数、相同的节点/层数)。如果不是,您将遇到映射变量的问题(在一个模型中有变量,而在另一个模型中则没有。

你想做的是有三个疗程。前两个你从检查点加载,最后一个将保持平均值。您希望这样做,因为每个会话都将包含变量值的一个版本。

加载模型之后,使用tf.trainable_variables()获取模型中所有变量的列表。您可以将其传递给sess.run,以获得作为numpy数组的变量。计算平均值之后,使用tf.assign创建操作来更改变量。您也可以使用列表来更改初始化器,但这意味着传递给模型(并不总是一个选项)。

大致如下:

代码语言:javascript
复制
graph = tf.Graph()
session1 = tf.Session()
session2 = tf.Session()
session3 = tf.Session()

# Omitted code: Restore session1 and session2.
# Optionally initialize session3.

all_vars = tf.trainable_variables()
values1 = session1.run(all_vars)
values2 = session2.run(all_vars)

all_assign = []
for var, val1, val2 in zip(all_vars, values1, values2):
  all_assign.append(tf.assign(var, tf.reduce_mean([val1,val2], axis=0)))

session3.run(all_assign)

# Do whatever you want with session 3.
票数 5
EN

Stack Overflow用户

发布于 2020-05-12 12:28:12

通过使用tf.train.list_variablestf.train.load_checkpoint,您可以对任何检查点、任何模型以非常通用的方式实现这一点。

您可以找到一个示例这里

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50373678

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档