首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >集合面预训练模型的令牌器和模型对象具有不同的最大输入长度

集合面预训练模型的令牌器和模型对象具有不同的最大输入长度
EN

Stack Overflow用户
提问于 2022-03-31 10:49:51
回答 2查看 617关注 0票数 0

我用的是symanto/sn-xlm-roberta-base-snli-mnli-anli-xnli经过预先训练的拥抱脸模型。我的任务需要在相当大的文本上使用它,所以必须知道最大输入长度。

以下代码应该加载预先训练过的模型及其标记器:

代码语言:javascript
复制
encoding_model_name = "symanto/sn-xlm-roberta-base-snli-mnli-anli-xnli"
encoding_tokenizer = AutoTokenizer.from_pretrained(encoding_model_name)
encoding_model = SentenceTransformer(encoding_model_name)

所以,当我打印关于他们的信息时:

代码语言:javascript
复制
encoding_tokenizer
encoding_model

我得到了:

代码语言:javascript
复制
PreTrainedTokenizerFast(name_or_path='symanto/sn-xlm-roberta-base-snli-mnli-anli-xnli', vocab_size=250002, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})
代码语言:javascript
复制
SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: XLMRobertaModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)

如您所见,令牌程序中的model_max_len=512参数与模型中的max_seq_length=128参数不匹配

我怎么知道哪一个是真的?或者,如果它们以某种方式响应了不同的特性,那么我如何检查模型的最大输入长度呢?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2022-04-02 01:55:14

由于您使用的是SentenceTransformer并将其加载到SentenceTransformer类,因此它将在128个令牌处截断输入,如文档所述(相关代码为这里):

属性max_seq_length 属性获取模型的最大输入序列长度。较长的输入将被截断。

你也可以自己检查一下:

代码语言:javascript
复制
fifty = model.encode(["This "*50], convert_to_tensor=True)
two_hundered = model.encode(["This "*200], convert_to_tensor=True)
four_hundered = model.encode(["This "*400], convert_to_tensor=True)

print(torch.allclose(fifty, two_hundered))
print(torch.allclose(two_hundered,four_hundered))

输出:

代码语言:javascript
复制
False
True

底层模型(Xlm)能够处理最多512个令牌的序列,但我假设西曼托将其限制为128个,因为它们在训练过程中也使用了这个限制(也就是说,对于超过128个令牌的序列,嵌入可能不是很好)。

票数 2
EN

Stack Overflow用户

发布于 2022-04-01 11:06:02

Model_max_length是该模型的最大位置嵌入长度。要检查这一点,请执行print(model.config)操作,您将看到"max_position_embeddings": 512和其他信任。

如何检查我的模型的最大输入长度?

在对文本序列进行编码时,可以传递max_length(尽可能多地传递模型):tokenizer.encode(txt, max_length=512)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71691184

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档