首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用AutoGrad包?

如何使用AutoGrad包?
EN

Stack Overflow用户
提问于 2018-02-02 21:07:57
回答 2查看 720关注 0票数 3

我正在试着做一件简单的事情:使用autograd来获得梯度,并做梯度下降:

代码语言:javascript
复制
import tangent

def model(x):
    return a*x + b

def loss(x,y):
    return (y-model(x))**2.0

在获得输入-输出对的损失后,我希望获得梯度wrt损失:

代码语言:javascript
复制
    l = loss(1,2)
    # grad_a = gradient of loss wrt a?
    a = a - grad_a
    b = b - grad_b

但是库教程没有展示如何获得关于a或b的渐变,即参数so,既不是autograd也不是切线。

EN

回答 2

Stack Overflow用户

发布于 2018-03-22 18:33:23

您可以使用grad函数的第二个参数来指定:

代码语言:javascript
复制
def f(x,y):
    return x*x + x*y

f_x = grad(f,0) # derivative with respect to first argument
f_y = grad(f,1) # derivative with respect to second argument

print("f(2,3)   = ", f(2.0,3.0))
print("f_x(2,3) = ", f_x(2.0,3.0)) 
print("f_y(2,3) = ", f_y(2.0,3.0))

在你的例子中,'a‘和'b’应该是损失函数的输入,损失函数将它们传递给模型以计算导数。

我刚刚回答了一个类似的问题:Partial Derivative using Autograd

票数 1
EN

Stack Overflow用户

发布于 2020-07-11 21:16:29

这里,这可能会有所帮助:

代码语言:javascript
复制
import autograd.numpy as np
from autograd import grad
def tanh(x):
  y=np.exp(-x)
  return (1.0-y)/(1.0+y)

grad_tanh = grad(tanh)

print(grad_tanh(1.0))

e=0.00001
g=(tanh(1+e)-tanh(1))/e
print(g)

输出:

代码语言:javascript
复制
0.39322386648296376
0.39322295790622513

以下是您可以创建的内容:

代码语言:javascript
复制
import autograd.numpy as np
from autograd import grad  # grad(f) returns f'

def f(x): # tanh
  y = np.exp(-x)
  return  (1.0 - y) / ( 1.0 + y)

D_f   = grad(f) # Obtain gradient function
D2_f = grad(D_f)# 2nd derivative
D3_f = grad(D2_f)# 3rd derivative
D4_f = grad(D3_f)# etc.
D5_f = grad(D4_f)
D6_f = grad(D5_f)

import  matplotlib.pyplot  as plt
plt.subplots(figsize = (9,6), dpi=153 )
x = np.linspace(-7, 7, 100)
plt.plot(x, list(map(f, x)),
         x, list(map(D_f , x)),
         x, list(map(D2_f , x)),
         x, list(map(D3_f , x)),
         x, list(map(D4_f , x)),
         x, list(map(D5_f , x)),
         x, list(map(D6_f , x)))
plt.show()

输出:

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48583421

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档