首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >具有相同TensorFlow网络的两个版本,具有不同的权重并从一个版本更新另一个版本

具有相同TensorFlow网络的两个版本,具有不同的权重并从一个版本更新另一个版本
EN

Stack Overflow用户
提问于 2018-02-02 07:08:32
回答 1查看 909关注 0票数 2

我正在尝试实现DeepMind用来训练AI玩Atari游戏的深度Q学习程序。他们使用并在多个教程中提到的功能之一是拥有两个版本的神经网络;一个在你循环小批量训练数据(称为这个Q)时更新,另一个在你这样做时调用,以帮助构建训练数据(Q')。然后周期性地(比如每10k个数据点)将Q‘中的权重设置为Q的当前值。

我的问题是,在TensorFlow中做这件事的最好方法是什么?同时存储两个相同的体系结构网络,并定期更新一个网络的权重。我的当前网络如下所示,当前仅使用默认图形和交互式会话。

代码语言:javascript
复制
sess = tf.InteractiveSession()

x = tf.placeholder(tf.float32, shape=[None, height, width, m])
y_ = tf.placeholder(tf.float32, shape=[None, env.action_space.n])

W_conv1 = weight_variable([8, 8, 4, 32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x, W_conv1, 4, 4) + b_conv1)

W_conv2 = weight_variable([4, 4, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_conv1, W_conv2, 2, 2) + b_conv2)

W_conv3 = weight_variable([3, 3, 64, 64])
b_conv3 = bias_variable([64])
h_conv3 = tf.nn.relu(conv2d(h_conv2, W_conv3, 1, 1) + b_conv3)

# Flattern conv to dense
flat_input_size = 14*10*64
h_conv3_reshape = tf.reshape(h_conv3, [-1, flat_input_size])

# Dense layers
W_fc1 = weight_variable([flat_input_size, 512])
b_fc1 = bias_variable([512])
h_fc1 = tf.nn.relu(tf.matmul(h_conv3_reshape, W_fc1) + b_fc1)

W_fc2 = weight_variable([512, env.action_space.n])
b_fc2 = bias_variable([env.action_space.n])
y_conv = tf.matmul(h_fc1, W_fc2) + b_fc2

accuracy = tf.squared_difference(y_, y_conv)
loss = tf.reduce_mean(accuracy)
optimizer = tf.train.AdamOptimizer(0.0001).minimize(loss)

tf.global_variables_initializer().run()
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-02-04 04:43:52

这里有一个安排的方法。首先,为每个网络制作一个单独的图,以便在不同的会话中并行运行它们:

代码语言:javascript
复制
graph1 = tf.Graph()
with graph1.as_default():
  model1 = build_model()

graph2 = tf.Graph()
with graph2.as_default():
  model2 = build_model()

..。其中build_model()定义了所有占位符、变量和训练操作。两个模型应该对变量使用相同的名称,这将允许它们轻松地交换状态。

每个网络都可以使用目标的另一个网络的快照进行训练(最新的或以前的最好的,取决于您)。每个网络定期通过tf.Saver()保存到磁盘,并使用其中一个网络的权重进行恢复。例如,此代码会将权重从第二个网络加载到第一个图中:

代码语言:javascript
复制
with tf.Session(graph=graph1) as sess:
  saver = tf.train.import_meta_graph('/tmp/model-2/network.meta')
  saver.restore(sess, '/tmp/model-2/network')
  ... continue training

下面是模型的保存方式:

代码语言:javascript
复制
with tf.Session(graph=graph1) as sess:
  ... do some training
  save_path = saver.save(sess, '/tmp/model-1/network')

有关在this question中保存和恢复的更多信息。您可以在同一会话中执行此操作,也可以启动一个新会话。

实际上,您甚至可以尝试在两个网络的磁盘上使用相同的位置,以便从同一文件进行保存和恢复。但这将迫使您拥有最新的快照,而以前的方法更灵活。

需要注意的一件事是会话的使用:为graph1创建的会话只能计算来自graph1的张量和操作。示例:

代码语言:javascript
复制
def build_model():
  x = tf.placeholder(tf.float32, name='x')
  y = tf.placeholder(tf.float32, name='y')
  z = x + y
  return x, y, z

graph1 = tf.Graph()
with graph1.as_default():
  x1, y1 ,z1 = build_model()

graph2 = tf.Graph()
with graph2.as_default():
  x2, y2, z2 = build_model()

with tf.Session(graph=graph1) as sess1:
  with tf.Session(graph=graph2) as sess2:
    # Good
    print(sess1.run(z1, feed_dict={x1: 1, y1: 2}))  # 3.0
    print(sess2.run(z2, feed_dict={x2: 3, y2: 1}))  # 4.0

    # BAD! Wrong graph
    # print(sess1.run(z2, feed_dict={x2: 3, y2: 1}))
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48573243

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档