首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch梯度下降

PyTorch梯度下降
EN

Stack Overflow用户
提问于 2018-09-06 23:24:42
回答 1查看 9.5K关注 0票数 11

我正尝试在PyTorch中手动实现梯度下降,作为学习练习。我有以下创建合成数据集的方法:

代码语言:javascript
复制
import torch
torch.manual_seed(0)
N = 100
x = torch.rand(N,1)*5
# Let the following command be the true function
y = 2.3 + 5.1*x
# Get some noisy observations
y_obs = y + 2*torch.randn(N,1)

然后,我创建我的预测函数(y_pred),如下所示。

代码语言:javascript
复制
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
y_pred = w*x+b
mse = torch.mean((y_pred-y_obs)**2)

它使用最小均方来推断权重w,b。我使用下面的块来根据梯度更新值。

代码语言:javascript
复制
gamma = 1e-2
for i in range(100):
  w = w - gamma *w.grad
  b = b - gamma *b.grad
  mse.backward()

但是,循环只在第一次迭代中工作。第二次迭代之后,w.grad 被设置为 None**.**,我很确定发生这种情况的原因是因为我将w设置为它自身的一个函数(我可能错了)。

问题是如何使用梯度信息正确地更新权重?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-09-07 08:57:48

  1. 在应用梯度下降之前,应该调用反向方法。
  2. 您需要使用新的权重来计算每次迭代的损失。
  3. 创建新的张量没有梯度带每次迭代。

下面的代码在我的计算机上运行良好,并在500次迭代训练后给出了w=5.1 & b=2.2。

代码:

代码语言:javascript
复制
import torch
torch.manual_seed(0)
N = 100
x = torch.rand(N,1)*5
# Let the following command be the true function
y = 2.3 + 5.1*x
# Get some noisy observations
y_obs = y + 0.2*torch.randn(N,1)

w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)


gamma = 0.01
for i in range(500):
    print(i)
    # use new weight to calculate loss
    y_pred = w * x + b
    mse = torch.mean((y_pred - y_obs) ** 2)

    # backward
    mse.backward()
    print('w:', w)
    print('b:', b)
    print('w.grad:', w.grad)
    print('b.grad:', b.grad)

    # gradient descent, don't track
    with torch.no_grad():
        w = w - gamma * w.grad
        b = b - gamma * b.grad
    w.requires_grad = True
    b.requires_grad = True

输出:

代码语言:javascript
复制
499
w: tensor([5.1095], requires_grad=True)
b: tensor([2.2474], requires_grad=True)
w.grad: tensor([0.0179])
b.grad: tensor([-0.0576])
票数 12
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52213282

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档