首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >att_mask与key_padding_mask在MultiHeadAttnetion中的区别

att_mask与key_padding_mask在MultiHeadAttnetion中的区别
EN

Stack Overflow用户
提问于 2020-06-29 00:31:03
回答 2查看 4.4K关注 0票数 15

att_maskkey_padding_maskMultiHeadAttnetion中的区别是什么?

key_padding_mask -如果提供,键中指定的填充元素将被注意忽略。当给定二进制掩码且值为True时,注意层上的相应值将被忽略。当给定字节掩码且值为非零时,注意层上的相应值将被忽略attn_mask - 2D或3D掩码,这会阻止对某些位置的注意。2D掩码将用于所有批次的广播,而3D掩码允许为每个批的条目指定不同的掩码。

提前谢谢。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-06-29 07:51:51

key_padding_mask用于屏蔽正在填充的位置,即在输入序列结束后。这总是特定于输入批处理,并取决于与最长的批处理相比,批中的序列是多长时间。它是形状、批次大小×输入长度的二维张量。

另一方面,attn_mask说什么键值对是有效的.在变压器译码器中,采用三角形掩码来模拟推理时间,防止对“未来”位置的注意。这就是通常使用att_mask的地方。如果是二维张量,则形状为输入长度×输入长度。您还可以拥有一个特定于批处理中的每个项目的掩码。在这种情况下,您可以使用三维张量的形状(批大小×num头)×输入长度×输入长度。(因此,从理论上讲,您可以使用三维key_padding_mask来模拟att_mask。)

票数 18
EN

Stack Overflow用户

发布于 2021-12-06 21:09:56

我认为它们的工作原理是一样的:两个掩码都定义了查询和键之间的注意不会被使用。这两种选择的唯一不同之处在于您更适合输入掩码的形状。

根据代码,这两个掩码似乎是合并/合并的,因此它们都扮演着相同的角色--查询和键之间的注意不会被使用。如果您需要使用两个掩码,则两个掩码输入可以是不同的值,或者您可以根据其所需的形状以方便的mask_args输入掩码:下面是函数multi_head_attention_forward()中第5227行吡咯烷酮/功能。原始代码的一部分。

代码语言:javascript
复制
...
# merge key padding and attention masks
if key_padding_mask is not None:
    assert key_padding_mask.shape == (bsz, src_len), \
        f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
    key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).   \
        expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
    if attn_mask is None:
        attn_mask = key_padding_mask
    elif attn_mask.dtype == torch.bool:
        attn_mask = attn_mask.logical_or(key_padding_mask)
    else:
        attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
...
# so here only the merged/unioned mask is used to actually compute the attention
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)

如果你有不同的意见或者我错了,请纠正我。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62629644

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档