首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >加载预训练的BERT模型问题

加载预训练的BERT模型问题
EN

Stack Overflow用户
提问于 2021-03-02 23:52:08
回答 1查看 465关注 0票数 3

我正在使用Huggingface进一步训练BERT模型。我使用两种方法保存模型:步骤(1)使用以下代码保存整个模型:model.save_pretrained(save_location),以及步骤(2)使用以下代码保存模型的state_dict:torch.save(model.state_dict(),'model.pth')然而,当我尝试使用步骤(1)的代码bert_mask_lm = BertForMaskedLM.from_pretrained('save_location')和步骤(2)的torch.load('model.pth')加载这个预先训练好的BERT模型时,我在两个步骤中都得到了以下错误:

代码语言:javascript
复制
AttributeError                            Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/torch/serialization.py in _check_seekable(f)
    307     try:
--> 308         f.seek(f.tell())
    309         return True

AttributeError: 'torch._C.PyTorchFileReader' object has no attribute 'seek'

During handling of the above exception, another exception occurred:

步骤(1)的详细堆栈跟踪如下:

代码语言:javascript
复制
AttributeError                            Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/torch/serialization.py in _check_seekable(f)
    307     try:
--> 308         f.seek(f.tell())
    309         return True

AttributeError: 'torch._C.PyTorchFileReader' object has no attribute 'seek'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   1037             try:
-> 1038                 state_dict = torch.load(resolved_archive_file, map_location="cpu")
   1039             except Exception:

~/anaconda3/lib/python3.6/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    593                     return torch.jit.load(opened_file)
--> 594                 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    595         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

~/anaconda3/lib/python3.6/site-packages/moxing/framework/file/file_io_patch.py in _load(f, map_location, pickle_module, **pickle_load_args)
    199 
--> 200     _check_seekable(f)
    201     f_should_read_directly = _should_read_directly(f)

~/anaconda3/lib/python3.6/site-packages/torch/serialization.py in _check_seekable(f)
    310     except (io.UnsupportedOperation, AttributeError) as e:
--> 311         raise_err_msg(["seek", "tell"], e)
    312     return False

~/anaconda3/lib/python3.6/site-packages/torch/serialization.py in raise_err_msg(patterns, e)
    303                                 + " try to load from it instead.")
--> 304                 raise type(e)(msg)
    305         raise e

AttributeError: 'torch._C.PyTorchFileReader' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

During handling of the above exception, another exception occurred:

OSError                                   Traceback (most recent call last)
~/work/algo-FineTuningBert3/FineTuningBert3.py in <module>()
      1 #Model load checking
----> 2 loadded_model = BertForMaskedLM.from_pretrained('/cache/raw_model/')

~/anaconda3/lib/python3.6/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   1039             except Exception:
   1040                 raise OSError(
-> 1041                     f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
   1042                     f"at '{resolved_archive_file}'"
   1043                     "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "

OSError: Unable to load weights from pytorch checkpoint file for '/cache/raw_model/' at '/cache/raw_model/pytorch_model.bin'If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. 

我使用的是最新的torch (1.7.1)和transformers (4.3.3)包。我不清楚是什么导致了这个错误,以及如何解决这个问题。

EN

回答 1

Stack Overflow用户

发布于 2021-09-06 07:41:36

我也在经历同样的事情。事实证明,这可能是由于PyTorch和转换器的版本差异造成的。它必须是版本特定的。

我在没有下载最新的bert-base-uncased模型的情况下使用了以下内容:

代码语言:javascript
复制
pip install torch==1.5.1
pip install transformers==3.0.2

MODEL_NAME = 'bert-base-uncased'
model = BertForTokenClassification.from_pretrained(
    MODEL_NAME
)

这将自动下载与适当版本的transformers相关的预训练BERT模型注意:我单独从官方网站显式下载了vocab.txt,并将其与BERT tokenizer类一起使用。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66442648

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档