首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >x**2函数的随机梯度下降与梯度下降

x**2函数的随机梯度下降与梯度下降
EN

Stack Overflow用户
提问于 2021-11-04 18:22:07
回答 1查看 40关注 0票数 0

我想通过一个最简单的函数示例:y=x**2来理解SGD和GD之间的区别。

GD的功能如下:

代码语言:javascript
复制
def gradient_descent(
    gradient, start, learn_rate, n_iter=50, tolerance=1e-06
):
    vector = start
    for _ in range(n_iter):
        diff = -learn_rate * gradient(vector)
        if np.all(np.abs(diff) <= tolerance):
            break
        vector += diff
    return vector

为了找到x**2函数的最小值,我们接下来应该做(答案几乎是0,这是正确的):

代码语言:javascript
复制
gradient_descent(gradient=lambda v: 2 * x, start=10.0, learn_rate=0.2)

我如何理解,在经典的GD中,梯度是从所有数据点精确计算出来的。我在上面展示的实现中的“所有数据点”是什么?

此外,我们应该如何将该函数现代化,以便将其命名为SGD (SGD使用单个数据点来计算梯度。( gradient_descent函数中的“单点”在哪里?)

EN

回答 1

Stack Overflow用户

发布于 2021-11-05 08:16:06

在您的示例中最小化的函数不依赖于任何数据,因此说明GD和SGD之间的区别没有任何帮助。

考虑这个例子:

代码语言:javascript
复制
import numpy as np

rng = np.random.default_rng(7263)

y = rng.normal(loc=10, scale=4, size=100)


def loss(y, mean):
    return 0.5 * ((y-mean)**2).sum()


def gradient(y, mean):
    return (mean - y).sum()


def mean_gd(y, learning_rate=0.005, n_iter=15, start=0):
    """Estimate the mean of y using gradient descent"""

    mean = start

    for i in range(n_iter):

        mean -= learning_rate * gradient(y, mean)

        print(f'Iter {i} mean {mean:0.2f} loss {loss(y, mean):0.2f}')

    return mean


def mean_sgd(y, learning_rate=0.005, n_iter=15, start=0):
    """Estimate the mean of y using stochastic gradient descent"""

    mean = start

    for i in range(n_iter):

        rng.shuffle(y)
        for single_point in y:
            mean -= learning_rate * gradient(single_point, mean)

        print(f'Iter {i} mean {mean:0.2f} loss {loss(y, mean):0.2f}')

    return mean


mean_gd(y)
mean_sgd(y)
y.mean()

使用GD和SGD的两个(非常简单的)版本来估计随机样本y的均值。估计平均值是通过最小化平方loss来实现的。正如你所理解的那样,在GD中,每次更新都使用在整个数据集上计算的梯度,而在SGD中,我们一次只看一个随机点。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69844053

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档