首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在使用fast-ai时修复为AWD_LSTM加载state_dict时出现的错误

如何在使用fast-ai时修复为AWD_LSTM加载state_dict时出现的错误
EN

Stack Overflow用户
提问于 2019-04-25 18:39:37
回答 1查看 1.2K关注 0票数 2

我使用fast-ai库来训练IMDB评论数据集的样本。我的目标是实现情感分析,我只想从一个小数据集开始(这个数据集包含1000条IMDB评论)。我已经使用this tutorial在VM中训练了该模型。

我保存了data_lmdata_clas模型,然后保存了编码器ft_enc,然后保存了分类器学习器sentiment_model。然后,我从虚拟机中获得了这4个文件,并将它们放入我的机器中,希望使用这些预先训练的模型来对情绪进行分类。

这是我所做的:

代码语言:javascript
复制
# Use the IMDB_SAMPLE file
path = untar_data(URLs.IMDB_SAMPLE)

# Language model data
data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')

# Sentiment classifier model data
data_clas = TextClasDataBunch.from_csv(path, 'texts.csv', 
                                       vocab=data_lm.train_ds.vocab, bs=32)

# Build a classifier using the tuned encoder (tuned in the VM)
learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5)
learn.load_encoder('ft_enc')

# Load the trained model
learn.load('sentiment_model')

在那之后,我想使用这个模型来预测句子的情绪。在执行这段代码时,我遇到了以下错误:

代码语言:javascript
复制
RuntimeError: Error(s) in loading state_dict for AWD_LSTM:
   size mismatch for encoder.weight: copying a param with shape torch.Size([8731, 400]) from checkpoint, the shape in current model is torch.Size([8888, 400]).
   size mismatch for encoder_dp.emb.weight: copying a param with shape torch.Size([8731, 400]) from checkpoint, the shape in current model is torch.Size([8888, 400]). 

回溯是:

代码语言:javascript
复制
Traceback (most recent call last):
  File "C:/Users/user/PycharmProjects/SentAn/mainApp.py", line 51, in <module>
    learn = load_models()
  File "C:/Users/user/PycharmProjects/SentAn/mainApp.py", line 32, in load_models
    learn.load_encoder('ft_enc')
  File "C:\Users\user\Desktop\py_code\env\lib\site-packages\fastai\text\learner.py", line 68, in load_encoder
    encoder.load_state_dict(torch.load(self.path/self.model_dir/f'{name}.pth'))
  File "C:\Users\user\Desktop\py_code\env\lib\site-packages\torch\nn\modules\module.py", line 769, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))

因此,在加载编码器时会出现错误。但是,我也尝试删除load_encoder行,但在下一行learn.load('sentiment_model')处发生了相同的错误。

我在fast-ai论坛上搜索了一下,发现其他人也有这个问题,但没有找到解决方案。在this post中,用户说这可能与不同的预处理有关,尽管我不能理解为什么会发生这种情况。

有谁知道我做错了什么吗?

EN

回答 1

Stack Overflow用户

发布于 2019-06-13 08:54:03

似乎data_clas和data_lm的词汇量是不同的。我猜这个问题是由data_clas和data_lm中使用的不同预处理引起的。为了验证我的猜测,我简单地使用了

data_clas.vocab.itos = data_lm.vocab.itos

在下面这行之前

learn_c = text_classifier_learner(data_clas,AWD_LSTM,drop_mult=0.3)

这已经修复了这个错误。

票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55847371

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档