如果我正确理解,变压器XL中的Key隐藏层是2L * d大小的,其中L是段长度,d是嵌入维。
两个隐序列沿长度维的级联
因此,注意矩阵的大小将是L X 2L,其中行i表示应该应用于每个2L Keys的注意Query i。
然而,在下面的图像中,片段长度为4,每个节点只有4条线。每个节点不应该有4*2=8行吗?

发布于 2022-11-29 13:36:46
如果您看一下Github代码,实际上,多头注意力函数中就有2xNxD:
class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False):
super(MultiHeadAttn, self).__init__()
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm kv指的是关键向量和值向量。
来源:https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
https://datascience.stackexchange.com/questions/116506
复制相似问题