首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >通过torch.topk求出梯度

通过torch.topk求出梯度
EN

Stack Overflow用户
提问于 2021-05-17 13:21:04
回答 2查看 2.1K关注 0票数 0

我想通过torch.topk函数求出梯度。

假设输入是一个向量

然后用一个参数矩阵对其进行变换。

并选取向量的顶部k值。

结果向量通过元素方向乘法进一步变换.

最后,损失由

我想知道,与W有关的损失是可微的吗?形式上,我们能计算出下面的梯度吗?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-05-18 09:58:08

topk()操作只是一个线性变换,用来选取张量的顶部k个元素。由于这是一种W @ X或矩阵向量乘法的运算,这也是可微的.

示例:下面我以两种方式计算了流水线操作topk(Wx),并显示了两者产生的梯度是相同的。

代码语言:javascript
复制
In [1]: import torch

In [2]: x1 = torch.rand(6, requires_grad = True)

In [3]: W1 = torch.rand(6, 6, requires_grad = True)

In [4]: x1
Out[4]: tensor([0.1511, 0.5990, 0.6338, 0.5137, 0.5203, 0.0560], requires_grad=True)

In [5]: W1
Out[5]: 
tensor([[0.2541, 0.6699, 0.5311, 0.7801, 0.5042, 0.5475],
        [0.7523, 0.1331, 0.7670, 0.8132, 0.0524, 0.0269],
        [0.3974, 0.2880, 0.9142, 0.9906, 0.4401, 0.3984],
        [0.7956, 0.2071, 0.2209, 0.6192, 0.2054, 0.7693],
        [0.8587, 0.8415, 0.6033, 0.3812, 0.2498, 0.9813],
        [0.9033, 0.0417, 0.2272, 0.1576, 0.9087, 0.3284]], requires_grad=True)

In [6]: y1 = W1 @ x1

In [7]: y1
Out[7]: tensor([1.4699, 1.1260, 1.5721, 0.8523, 1.3969, 0.8776], grad_fn=<MvBackward>)

In [8]: yk, _ = torch.topk(y1, 3)

In [9]: yk
Out[9]: tensor([1.5721, 1.4699, 1.3969], grad_fn=<TopkBackward>)

In [10]: loss1 = (yk ** 2).sum()

In [11]: loss1.backward()

In [12]: W1.grad
Out[12]: 
tensor([[0.4442, 1.7609, 1.8633, 1.5102, 1.5296, 0.1646],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4751, 1.8833, 1.9928, 1.6152, 1.6359, 0.1760],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4222, 1.6734, 1.7706, 1.4352, 1.4535, 0.1564],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

现在,让我们评估相同的操作集,但是显式地使用topk()作为一个线性转换。请注意,构造的Wk矩阵通过乘法有选择地从6元素张量中选择顶k(此处3)元素。

代码语言:javascript
复制
In [13]: x2 = torch.tensor([0.1511, 0.5990, 0.6338, 0.5137, 0.5203, 0.0560], req
    ...: uires_grad=True)

In [14]: W2 = torch.tensor([[0.2541, 0.6699, 0.5311, 0.7801, 0.5042, 0.5475],
    ...:         [0.7523, 0.1331, 0.7670, 0.8132, 0.0524, 0.0269],
    ...:         [0.3974, 0.2880, 0.9142, 0.9906, 0.4401, 0.3984],
    ...:         [0.7956, 0.2071, 0.2209, 0.6192, 0.2054, 0.7693],
    ...:         [0.8587, 0.8415, 0.6033, 0.3812, 0.2498, 0.9813],
    ...:         [0.9033, 0.0417, 0.2272, 0.1576, 0.9087, 0.3284]], requires_gra
    ...: d=True)

In [15]: y2 = W2 @ x2

In [16]: y2
Out[16]: tensor([1.4700, 1.1260, 1.5721, 0.8523, 1.3969, 0.8776], grad_fn=<MvBackward>)

# Use the indices obtained earlier to construct the matrix
In [19]: _
Out[19]: tensor([2, 0, 4])

In [20]: k = 3

In [21]: Wk = torch.zeros(k, y2.shape[0])

In [22]: Wk[torch.arange(k), _] = 1

In [23]: Wk.requires_grad = True

In [24]: Wk
Out[24]: 
tensor([[0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.]], requires_grad=True)


In [25]: yk2 = Wk @ y2

In [26]: yk2
Out[26]: tensor([1.5721, 1.4700, 1.3969], grad_fn=<MvBackward>)

In [27]: loss2 = (yk2 ** 2).sum()

In [28]: loss2.backward()

现在比较在这两种情况下获得的梯度:

代码语言:javascript
复制
In [29]: W2.grad
Out[29]: 
tensor([[0.4442, 1.7611, 1.8634, 1.5103, 1.5297, 0.1646],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4751, 1.8834, 1.9929, 1.6152, 1.6360, 0.1761],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4222, 1.6735, 1.7707, 1.4352, 1.4536, 0.1565],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

In [30]: W1.grad
Out[30]: 
tensor([[0.4442, 1.7609, 1.8633, 1.5102, 1.5296, 0.1646],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4751, 1.8833, 1.9928, 1.6152, 1.6359, 0.1760],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4222, 1.6734, 1.7706, 1.4352, 1.4535, 0.1564],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

In [31]: x1.grad
Out[31]: tensor([4.3955, 5.2256, 6.1213, 6.4732, 3.5637, 5.6037])

In [32]: x2.grad
Out[32]: tensor([4.3957, 5.2261, 6.1215, 6.4733, 3.5641, 5.6040])

正如您所看到的,结果与一些浮点错误是相同的,这些错误是在我复制x1W1的值时引入的,而没有获得它们的全部精度。

票数 4
EN

Stack Overflow用户

发布于 2021-05-31 08:26:52

多亏了@swag2198,我才弄明白了这一点,我在我的存储库上发布了这个配方

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67570529

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档