首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >详解梯度下降

详解梯度下降

原创
作者头像
用户11850891
发布2025-09-25 17:34:00
发布2025-09-25 17:34:00
2920
举报

在我研究人工智能的过程中,梯度下降法曾经困扰了我很长时间。但是作为模型训练的基础,梯度下降法又是非常重要的概念。接下来我给大家详细讲一讲我对这个概念的理解。

在开始之前,先解释一下损失函数

训练样本输入模型后产生的输出值和(该样本的)标签值相减,得到的差值就叫做损失(Loss),不同样本、不同权重(不同训练阶段)的模型产生损失有多有少,损失函数就是用数学方式表示(衡量)损失多少的函数。它是一个非负实值函数,通常用L(Y, f(x))来表示。我们一般说的损失函数衡量的是在整个训练样本空间上的整体损失。

比如均方差公式:

就是一个常见的损失函数。

想要理解梯度下降,第一步需要理解:模型的训练过程,可以表达为:对以参数为自变量的损失函数求最小值的过程(如果你已经理解,请直接跳到后面看第二步)。

现在假设模型是最简单的一元线性模型y=ax,模型训练和常规解方程的不同是模型训练是已知x、y,求a,使得L取得最小值。而梯度下降正是一种通过逐步逼近求a的方法。

为了更好理解,我们看下图:

假设样本空间中有数个样本点,从图中能看出到所有样本点平均距离最小的直线只有一条,那么直线到点的平均距离公式就是我们找的损失函数。

因为点到直线的垂直距离计算公式复杂,以点(x_1, y_1)((x_1, ax_1) 的y轴距离为替代公式,即用下图中红色虚线距离代替蓝色虚线的距离:

此时的公式为:

该公式即均方差公式。

现在我们假设样本空间只有一个样本点(1,1),此时均方差公式为(1-a)^2 。该函数的曲线如下图:

不难看出a=1时该函数取得最小值0。此时模型为y=x,即在坐标系中正好通过点(1,1):

也就是说a=1就是我们要求的参数值。

第二步:

明确了第一步的概念后,将一元参数a扩展到二元参数a,b(即模型为y=ax+b),样本点为(0,1)、(1,2)的情况下。此时MSE函数为:

(1/2) a^2+ab+b^2-2a-3b+(5/2)

该函数的几何形式如图:

该曲面是一个椭圆抛物面,顶点在(1, 1, 0),开口向上(函数表达式具体推理过程稍后我会写进附录)。由于技术原因,上面这个由程序生成的图像对抛物面的展示不明显,一个更明显的展示是:

这时我们可以看出,对于任意参数值a,b,代入MSE后会得到一个值,即这个曲面上的一个点(下图中靠近顶部的红点)。此时这个点(沿该点所在切面)可以有多个运动方向(下图中红色箭头所示),但只有一个方向(图中黑线所示)可以最快到达这个曲面的最低点。

那么在数学上怎么表示这个方向呢?这个方向就是梯度的反方向。梯度的数学表示是损失函数(MSE)在该点处对a,b的偏导数组成的向量(下图中黑色箭头所示),沿梯度的方向,y的增长速度最快:

沿梯度的反方向一次“前进”一小步,就可以用最快的速度到达这个曲面的最低点(即MSE的最小值),这种不断更新a、b的值直到到达最优解方法就被称为梯度下降法:

后记

纠正网上对于梯度下降法的一些误解。也许你们有人听过“梯度下降表示沿最陡的方向下降到(损失函数)最低点”;也许你们当中有人看到过这样的图:

这张图对于梯度下降的造成的困惑点在于:函数上的任意点只有两个运动方向(即点cita 0处红色切线所示的两个方向)如果选错了方向那永远到不了最低点。

以上就是我的一点心得,如果您有什么意见和建议欢迎提出!如果您觉得我写得还可以,欢迎点赞、收藏和分享!另外转载还请注明出处!

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 后记
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档