首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在一个模型中计算两个损失并反向传播两次

在一个模型中计算两个损失并反向传播两次
EN

Stack Overflow用户
提问于 2020-12-17 09:21:05
回答 1查看 207关注 0票数 0

我正在使用BertModel创建一个模型来识别答案广度(而不使用BertForQA)。

我有一个独立的线性层,用于分别确定开始和结束令牌。在init()中:

代码语言:javascript
复制
self.start_linear = nn.Linear(h, output_dim)

self.end_linear = nn.Linear(h, output_dim)

在前进()中,我输出了一个预测的开始层和预测的结束层:

代码语言:javascript
复制
 def forward(self, input_ids, attention_mask):

 outputs = self.bert(input_ids, attention_mask) # input = bert tokenizer encoding

 lhs = outputs.last_hidden_state # (batch_size, sequence_length, hidden_size)

 out = lhs[:, -1, :] # (batch_size, hidden_dim)

 st = self.start_linear(out)

 end = self.end_linear(out) 



 predict_start = self.softmax(st)

 predict_end = self.softmax(end)

 return predict_start, predict_end

然后在train_epoch()中,我尝试分别反向传播损失:

代码语言:javascript
复制
def train_epoch(model, train_loader, optimizer):

 model.train()

 total = 0

 st_loss, st_correct, st_total_loss = 0, 0, 0

 end_loss, end_correct, end_total_loss = 0, 0, 0

 for batch in train_loader:

   optimizer.zero_grad()

   input_ids = batch['input_ids'].to(device)

   attention_mask = batch['attention_mask'].to(device)

   start_idx = batch['start'].to(device)

   end_idx = batch['end'].to(device)

   start, end = model(input_ids=input_ids, attention_mask=attention_mask)


   st_loss = model.compute_loss(start, start_idx)

   end_loss = model.compute_loss(end, end_idx)

   st_total_loss += st_loss.item()

   end_total_loss += end_loss.item()

 # perform backward propagation to compute the gradients

   st_loss.backward()

   end_loss.backward()

 # update the weights

   optimizer.step() 

但是我接了end_loss.backward()的电话

代码语言:javascript
复制
Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

我应该分开做后传吗?还是我应该用另一种方式来做?谢谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-12-17 13:27:26

标准程序是将损失和反向传播都加在和上。

重要的是要确保您想要的两种损失的平均值大致相同大,或者至少与您希望每个损失相对于另一种损失的重要性成正比(否则,模型将对较大的损失进行优化,而不是对较小的损失进行优化)。在span检测的情况下,由于问题的明显对称性,我猜这是不必要的。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65337804

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档