首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么在提取roBERTa时冻结位置嵌入?

为什么在提取roBERTa时冻结位置嵌入?
EN

Stack Overflow用户
提问于 2020-03-20 18:28:34
回答 1查看 281关注 0票数 1

我对huggingface的distillBERT工作很感兴趣,通过查看他们的代码(https://github.com/huggingface/transformers/blob/master/examples/distillation/train.py),我发现如果使用roBERTa作为学生模型,他们会冻结位置嵌入,我想知道这是为了什么?

代码语言:javascript
复制
def freeze_pos_embeddings(student, args):
    if args.student_type == "roberta":
        student.roberta.embeddings.position_embeddings.weight.requires_grad = False
    elif args.student_type == "gpt2":
        student.transformer.wpe.weight.requires_grad = False

我理解冻结token_type_embeddings的原因,因为roBERTa从来不使用段嵌入,但是为什么是位置嵌入呢?

非常感谢你的帮助!

EN

回答 1

Stack Overflow用户

发布于 2020-03-23 16:05:47

在大多数(甚至所有)常用的Transformers中,位置嵌入不是经过训练的,而是使用分析描述的函数( Attention is all you need论文第6页上的未编号方程)定义的:

为了节省Transformer package中的计算时间,它们被预先计算到512的长度,并存储为用作缓存的变量,该缓存在训练期间不应改变。

不训练位置嵌入的原因是,后面位置的嵌入将训练不足,但通过巧妙地解析定义的位置嵌入,网络可以学习方程背后的规律性,并更容易对更长的序列进行推广。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60772384

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档