首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在教育模式培训中处理NotImplementedError?

如何在教育模式培训中处理NotImplementedError?
EN

Stack Overflow用户
提问于 2022-07-06 04:18:04
回答 1查看 96关注 0票数 0
代码语言:javascript
复制
def train_fn(data_loader, model, optimizer):

model.train()
total_loss = 0.0

for images, masks in tqdm(data_loader):

  images = images.to(DEVICE)
  masks = masks.to(DEVICE)

  optimizer.zero_grad()
  logits, loss = model(images,masks)
  loss.backward()
  optimizer.step()

  total_loss += loss.item()



return total_loss/ len(data_loader)


def eval_fn(data_loader, model):

model.eval()
total_loss = 0.0

with torch.no_grad():

  for images, masks in tqdm(data_loader):

    images = images.to(DEVICE)
    masks = masks.to(DEVICE)

    logits, loss = model(images,masks)


    total_loss += loss.item()


return total_loss/ len(data_loader)

optimizer = torch.optim.Adam(model.parameters(), lr = LR)

best_valid_loss = np.Inf

for i in range(EPOCHS):


train_loss = train_fn(trainloader, model, optimizer)
valid_loss = eval_fn(validloader, model)

if valid_loss < best_valid_loss:
  torch.save(model.state_dict(), 'best_model.pt')
  print("SAVED_MODEL")
  best_valid_loss = valid_loss

打印(f“时代:{i+1} Train_loss:{train_loss} Valid_loss:{valid_loss}")

当我试图训练模型时,我得到了以下错误:

0%\x{e76f} 0/15 00:00<?

NotImplementedError跟踪(最近一次调用) in () 4 5->6 train_loss =train_fn(列车装载器、模型、优化器)7 valid_loss =eval_fn(验证加载程序,模型)8

2帧/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py在_forward_unimplemented(self,*input) 199注册钩子,而后者无声地忽略它们。200“”-> 201 #举重NotImplementedError 202 203

NotImplementedError:

我该怎么处理呢?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-07-06 20:44:15

查看注释中提供的链接,您的模型定义如下所示:

代码语言:javascript
复制
class SegmentationModel(nn.Module):

  def __init__(self):
    super(SegmentationModel,self).__init__()

    self.arc = smp.Unet(
        encoder_name = ENCODER,
        encoder_weights = WEIGHTS,
        in_channels = 3,
        classes = 1,
        activation = None
    )

    def forward(self, images, masks = None):
      logits = self.arc(images)

      if masks != None:
        loss1 = DiceLoss(mode = 'binary')(logits, masks)
        loss2 = nn.BCEWithLogitsLoss()(logits,masks)
        return logits, loss1 + loss2

      return logits

如果您仔细观察,您将看到forward()有一个不规则的额外缩进,使它成为__init__()内部的一个函数,而不是SegmentationModel的一个方法。把它往左移一点,它应该工作得很好:

代码语言:javascript
复制
class SegmentationModel(nn.Module):

  def __init__(self):
    super(SegmentationModel,self).__init__()

    self.arc = smp.Unet(
        encoder_name = ENCODER,
        encoder_weights = WEIGHTS,
        in_channels = 3,
        classes = 1,
        activation = None
    )

  def forward(self, images, masks = None):
    logits = self.arc(images)

    if masks != None:
      loss1 = DiceLoss(mode = 'binary')(logits, masks)
      loss2 = nn.BCEWithLogitsLoss()(logits,masks)
      return logits, loss1 + loss2

    return logits
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72877910

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档