首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >变压器XL -理解纸的插图

变压器XL -理解纸的插图
EN

Data Science用户
提问于 2022-11-26 17:09:22
回答 1查看 62关注 0票数 2

如果我正确理解,变压器XL中的Key隐藏层是2L * d大小的,其中L是段长度,d是嵌入维。

两个隐序列沿长度维的级联

因此,注意矩阵的大小将是L X 2L,其中行i表示应该应用于每个2L Keys的注意Query i

,即自我注意窗口长度=2X段长.

然而,在下面的图像中,片段长度为4,每个节点只有4条线。每个节点不应该有4*2=8行吗?

连接变压器XL纸

EN

回答 1

Data Science用户

发布于 2022-11-29 13:36:46

如果您看一下Github代码,实际上,多头注意力函数中就有2xNxD:

代码语言:javascript
复制
class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, 
                 pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm 

kv指的是关键向量和值向量。

来源:https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py

票数 1
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/116506

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档