我正在尝试用Deepmind在张量流中实现DQN文件,并且在裁剪损失函数时遇到了困难。
以下是描述损失裁剪的自然文件的摘录:
由于绝对值损失函数x对x的所有负值都有−1的导数,对x的所有正值有1的导数,因此将−1与1之间的平方误差裁剪为−1和1对应于(−1,1)区间以外的误差。这种形式的误差裁剪进一步提高了算法的稳定性。
(链接至全文:http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
到目前为止,我尝试的是使用
clipped_loss_vec = tf.clip_by_value(loss, -1, 1)为了减少损失,我计算在-1和+1之间。在这种情况下,代理没有学习适当的策略。我打印出了网络的梯度,并意识到如果损失降到-1以下,梯度就会突然变为0!
我对这种情况的推理是,在(-inf,-1) U (1,inf)中,剪裁损失是一个常数函数,这意味着这些区域的梯度为零。这反过来又确保了整个网络的梯度为零(想想看,无论我提供给网络的任何输入图像,损失都保持在-1,因为它已经被剪裁了)。
因此,我的问题有两部分:
发布于 2016-07-13 23:09:01
我怀疑它们意味着你应该将梯度剪裁为-1,1,而不是剪辑损失函数。因此,您可以像往常一样计算梯度,然后将梯度的每个组件剪辑到范围-1,1中;然后在梯度下降更新步骤中使用结果,而不是使用未经修改的梯度。
等效:定义函数f如下:
f(x) = x^2 if x in [-0.5,0.5]
f(x) = |x| - 0.25 if x < -0.5 or x > 0.5他们不使用某种形式的s^2作为损失函数(其中s是一些复杂的表达式),而是建议使用f(s)作为损失函数。这是一种平方损失和绝对值损失之间的混合:当s很小时,它的行为就像s,但是当s变大时,它的行为就像绝对值(|s|)。
注意,f的导数具有一个很好的性质,它的导数总是在-1,1的范围内。
f'(x) = 2x if x in [-0.5,0.5]
f'(x) = +1 if x > +1
f'(x) = -1 if x < -1因此,当你取这个f-based损失函数的梯度时,结果将和计算一个平方损失的梯度,然后剪裁它一样。
因此,他们所做的是有效地用Huber损失替换平方损失.函数f仅是δ= 0.5的Huber损失的两倍。
现在的要点是,以下两种选择是等价的:
前者易于实施。后者具有很好的特性(提高了稳定性;它比绝对值损失更好,因为它避免了围绕最小值的振荡)。由于两者是等价的,这意味着我们得到了一个易于实现的方案,它具有简单的平方损失和Huber损失的稳定性和鲁棒性。
发布于 2017-05-01 13:47:33
首先,本文的代码是可在线获得,这是一个非常有价值的参考。
第1部分
如果您查看代码,您将看到,在nql:getQUpdate (NeuralQLearner.lua,第180行)中,它们剪辑了Q-学习函数的错误术语:
-- delta = r + (1-terminal) * gamma * max_a Q(s2, a) - Q(s, a)
if self.clip_delta then
delta[delta:ge(self.clip_delta)] = self.clip_delta
delta[delta:le(-self.clip_delta)] = -self.clip_delta
end第2部分
在TensorFlow中,假设神经网络的最后一层称为self.output,self.actions是所有动作的一种热编码,self.q_targets_是带有目标的占位符,而self.q是计算出的Q:
# The loss function
one = tf.Variable(1.0)
delta = self.q - self.q_targets_
absolute_delta = tf.abs(delta)
delta = tf.where(
absolute_delta < one,
tf.square(delta),
tf.ones_like(delta) # squared error: (-1)^2 = 1
)或者,使用tf.clip_by_value (并具有更接近原始实现的实现):
delta = tf.clip_by_value(
self.q - self.q_targets_,
-1.0,
+1.0
) 发布于 2016-10-25 09:05:44
为了有一个“光滑”的复合损失函数,你也可以在边界值1和1处用一阶泰勒近似替换误差范围以外的平方损失1和1。在这种情况下,如果e是你的误差值,你可以在e< -1的情况下,用-2e-1代替它,如果e> 1,用2e-1代替它。
https://stackoverflow.com/questions/36462962
复制相似问题