首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Tensorflow中正确地设置Adadelta算法的参数?

如何在Tensorflow中正确地设置Adadelta算法的参数?
EN

Stack Overflow用户
提问于 2016-07-28 09:36:13
回答 1查看 7.3K关注 0票数 9

我一直在使用Tensorflow进行回归。我的神经网络很小,有10个输入神经元,12个单层隐神经元和5个输出神经元。

  • 激活函数
  • 成本是产出与实际价值之间的平方距离。
  • 我的神经网络与其他优化器(如GradientDescent、Adam、Adagrad )进行了正确的训练。

然而,当我尝试使用阿德罗塔时,神经网络根本就不会训练。变量在每一步都保持不变。

我尝试了每一个初始learning_rate可能(从1.0e-6到10)和不同的权值初始化:它总是一样的。

有人对发生了什么事有一点了解吗?

非常感谢

EN

回答 1

Stack Overflow用户

发布于 2016-07-28 12:15:58

简短的回答:不要用阿德罗塔

现在很少有人使用它,你应该坚持:

  • 具有tf.train.MomentumOptimizer动量的0.9是非常标准和工作良好的。缺点是你必须找到最好的学习速度。
  • tf.train.RMSPropOptimizer:结果不太依赖于良好的学习速度。这个算法与非常相似,但在我看来效果更好。

如果您真的想使用Adadelta,请使用论文中的参数:learning_rate=1., rho=0.95, epsilon=1e-6。一个更大的epsilon在一开始会有所帮助,但要准备等待比其他优化器稍长一点,以看到收敛。

请注意,在本文中,他们甚至不使用学习速率,这与保持它与1相等。

长答案

阿德罗塔的开局很慢。的完整算法是:

问题是它们积累了更新的平方。

  • 在步骤0,这些更新的运行平均值为零,因此第一个更新将非常小。
  • 由于第一次更新非常小,在开始时更新的运行平均值将非常小,这在开始时是一种恶性循环。

我认为Adadelta在更大的网络中的表现要好于您的网络,并且经过一些迭代之后,它应该与RMSProp或Adam的性能相当。

下面是我的代码,可以使用Adadelta优化器进行一些操作:

代码语言:javascript
复制
import tensorflow as tf

v = tf.Variable(10.)
loss = v * v

optimizer = tf.train.AdadeltaOptimizer(1., 0.95, 1e-6)
train_op = optimizer.minimize(loss)

accum = optimizer.get_slot(v, "accum")  # accumulator of the square gradients
accum_update = optimizer.get_slot(v, "accum_update")  # accumulator of the square updates

sess = tf.Session()
sess.run(tf.initialize_all_variables())

for i in range(100):
    sess.run(train_op)
    print "%.3f \t %.3f \t %.6f" % tuple(sess.run([v, accum, accum_update]))

前10行:

代码语言:javascript
复制
  v       accum     accum_update
9.994    20.000      0.000001
9.988    38.975      0.000002
9.983    56.979      0.000003
9.978    74.061      0.000004
9.973    90.270      0.000005
9.968    105.648     0.000006
9.963    120.237     0.000006
9.958    134.077     0.000007
9.953    147.205     0.000008
9.948    159.658     0.000009
票数 8
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/38632536

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档