首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在nn.Module之外创建和使用PyTorch可学习标量变量?

如何在nn.Module之外创建和使用PyTorch可学习标量变量?
EN

Stack Overflow用户
提问于 2020-11-23 13:45:12
回答 1查看 298关注 0票数 1

我正在处理一个多目标问题,其中我需要计算多个损失,而总损失就是这些损失的总和。我想让PyTorch可学习的浮点参数alphabeta作为单个损失的系数。请注意,损失的总和发生在训练循环中我的NN模型之外:

代码语言:javascript
复制
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

for batch in dl:

    optimizer.zero_grad()

    result = model(batch)

    loss1 = loss_fn_1(result)
    loss2 = loss_fn_2(result)
    loss3 = loss_fn_3(result)

    loss = alpha*loss1 + beta*loss2 + (1-beta)*loss3 # How to optimize alpha, beta?
                        
    loss.backward()
    optimizer.step()

如何声明和使用可学习参数alphabeta

EN

回答 1

Stack Overflow用户

发布于 2020-11-23 15:47:18

您可以将它们放到列表中,然后将它们添加到优化器中,例如,

代码语言:javascript
复制
optimizer_for_my_params = torch.Adam([alpha, beta], lr=1e-3)

或者分开,

代码语言:javascript
复制
optimizer_alpha = torch.Adam([alpha], lr=1e-3)
optimizer_beta = torch.Adam([beta], lr=1e-3)

在每个步骤中,在所有优化器上调用zero_gradstep

或者,您可以将它们放在nn.Module中并将其声明为参数:

代码语言:javascript
复制
class MyParams(nn.Module):
  def __init__(self):
    super(MyParams, self).__init__()

    self.alpha = nn.Parameter(torch.tensor(0.))
    self.beta = nn.Parameter(torch.tensor(0.))

  def forward(self, loss1, loss2, loss3):
    loss = self.alpha*loss1 + self.beta*loss2 + (1 - self.beta)*loss3
    return loss

在使用它时,为类对象定义一个单独的优化器就可以完成这项工作。

更新:这里是第一种方法的更全面的示例。

代码语言:javascript
复制
import torch
import torch.optim as optim

alpha = torch.tensor(0.)
alpha.requires_grad = True
optimizer_alpha = optim.Adam([alpha], lr=1e-3)

print(optimizer_alpha)
# Adam (
# Parameter Group 0
#     amsgrad: False
#     betas: (0.9, 0.999)
#     eps: 1e-08
#     lr: 0.001
#     weight_decay: 0
# )

out = alpha + 1

# test backward()
optimizer_alpha.zero_grad()
out.backward()
print(alpha.grad)
# tensor(1.)

# test step()
optimizer_alpha.step()
print(alpha)
# tensor(-0.0010, requires_grad=True)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64963125

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档