首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何理解google转换器教程中的自我注意掩码实现

如何理解google转换器教程中的自我注意掩码实现
EN

Stack Overflow用户
提问于 2022-09-24 07:14:34
回答 1查看 125关注 0票数 1

我正在阅读谷歌变压器教程,我不清楚为什么可以通过mask1 & mask2构建多头注意力的attention_mask。任何帮助都会很好!

代码语言:javascript
复制
  def call(self, x, training, mask):

    # A boolean mask.
    if mask is not None:
      mask1 = mask[:, :, None]
      mask2 = mask[:, None, :]
      attention_mask = mask1 & mask2          # <= here 
    else:
      attention_mask = None

    # Multi-head self-attention output (`tf.keras.layers.MultiHeadAttention `).
    attn_output = self.mha(
        query=x,  # Query Q tensor.
        value=x,  # Value V tensor.
        key=x,  # Key K tensor.
        attention_mask=attention_mask, # A boolean mask that prevents attention to certain positions.
        training=training, # A boolean indicating whether the layer should behave in training mode.
        )

玩具示例故障

代码语言:javascript
复制
input = tf.constant([
    [[1, 0, 3, 0], [1, 2, 0, 0]]
])

mask = tf.keras.layers.Embedding(2,2, mask_zero=True).compute_mask(input)
print(mask)
mask1 = mask[:, :, None]   # same as tf.expand_dims(mask, axis = 2)
print(mask1)
mask2 = mask[:, None, :]
print(mask2)

print(mask1 & mask2)

>

tf.Tensor(
[[[ True False  True False]
  [ True  True False False]]], shape=(1, 2, 4), dtype=bool)

tf.Tensor(
[[[[ True False  True False]]

  [[ True  True False False]]]], shape=(1, 2, 1, 4), dtype=bool)

tf.Tensor(
[[[[ True False  True False]
   [ True  True False False]]]], shape=(1, 1, 2, 4), dtype=bool)

<tf.Tensor: shape=(1, 2, 2, 4), dtype=bool, numpy=               # <= why built mask like this?
array([[[[ True, False,  True, False],
         [ True, False, False, False]],

        [[ True, False, False, False],
         [ True,  True, False, False]]]])>
EN

回答 1

Stack Overflow用户

发布于 2022-10-10 13:44:55

以下是我的理解。如果我错了就纠正我。

我认为理解注意掩码计算的关键是,多头注意的attention_mask和嵌入层产生的嵌入掩码的区别。

tf.keras.layers.Embedding是一个掩码生成层.

在输入形状为(batch_size,input_length)的情况下,tf.keras.layers.Embedding生成相同形状的嵌入掩码(batch_size,input_length),(docs/python/tf/keras/layers/Embedding#input-shape);

tf.keras.layers.MultiHeadAttention是掩模消耗层.

tf.keras.layers.Embedding的输出张量传递给tf.keras.layers.MultiHeadAttention时,嵌入掩码也需要传递给后一层。但是tf.keras.layers.MultiHeadAttention需要"attention_mask",这与嵌入掩码不同。"attention_mask“是形状(B,T,S) (1)的布尔掩码。B表示batch_size,T表示目标或查询,S表示源或键。

为了计算自我注意的注意掩码,我们基本上需要做一个外部乘积(产品)。这意味着,对于行令牌序列$X$,我们需要执行$X^T X$。结果是注意矩阵,其中每个元素都是从一个词到另一个词的注意力。注意力掩码将以同样形状的矩阵形式出现。

&运算符在mask1 & mask2中是tf.math.logical_and

理解tf.keras.layers.MultiHeadAttention中注意掩码的一个基本例子

代码语言:javascript
复制
sequence_a = "This is a very long sequence"
sequence_b = "This is short"

text = (sequence_a + ' ' + sequence_b).split(' ')

from sklearn import preprocessing
le = preprocessing.LabelEncoder()
le.fit(text)
print(le.classes_)

['This' 'a' 'is' 'long' 'sequence' 'short' 'very']

_tokens_a = le.transform(sequence_a.split(' ')) + 1 # 1-based
# print(_tokens_a)
_tokens_b = le.transform(sequence_b.split(' ')) + 1
# print(_tokens_b)

pad_b = tf.constant([[0,_tokens_a.size - _tokens_b.size]])
tokens_b = tf.pad(_tokens_b, pad_b)
tokens_a = tf.constant(_tokens_a)
print(tokens_a)

tf.Tensor([1 3 2 7 4 5], shape=(6,), dtype=int64)

print(tokens_b)

tf.Tensor([1 3 6 0 0 0], shape=(6,), dtype=int64)

padded_batch = tf.concat([tokens_a[None,:], tokens_b[None,:]], axis=0)
padded_batch  # Shape `(batch_size, input_seq_len)`.

标记化结果:

代码语言:javascript
复制
<tf.Tensor: shape=(2, 6), dtype=int64, numpy=
array([[1, 3, 2, 7, 4, 5],
       [1, 3, 6, 0, 0, 0]])>

嵌入口罩和注意面罩:

代码语言:javascript
复制
embedding = tf.keras.layers.Embedding(10, 4, mask_zero=True)
embedding_batch = embedding(padded_batch)
embedding_batch


<tf.Tensor: shape=(2, 6, 4), dtype=float32, numpy=
array([[[-0.0395105 ,  0.02781621, -0.02362361,  0.01861998],
        [ 0.02881015,  0.03395045, -0.0079098 , -0.002824  ],
        [ 0.02268535, -0.02632991,  0.03217204, -0.03376112],
        [ 0.04794324,  0.01584867,  0.02413819,  0.01202248],
        [-0.03509659,  0.04907972, -0.00174795, -0.01215838],
        [-0.03295932,  0.02424154, -0.04788723, -0.03202241]],

       [[-0.0395105 ,  0.02781621, -0.02362361,  0.01861998],
        [ 0.02881015,  0.03395045, -0.0079098 , -0.002824  ],
        [-0.02425164, -0.04932282,  0.0186419 , -0.01743554],
        [-0.00052293,  0.01411307, -0.01286217,  0.00627784],
        [-0.00052293,  0.01411307, -0.01286217,  0.00627784],
        [-0.00052293,  0.01411307, -0.01286217,  0.00627784]]],
      dtype=float32)>



embedding_mask = embedding_batch._keras_mask  # embedding.compute_mask(padded_batch)
embedding_mask

<tf.Tensor: shape=(2, 6), dtype=bool, numpy=
array([[ True,  True,  True,  True,  True,  True],
       [ True,  True,  True, False, False, False]])>




#This is self attention, thus Q and K are the same
my_mask1 = embedding_mask[:, :, None]  # eq: td[:,:,tf.newaxis]
my_mask1

<tf.Tensor: shape=(2, 6, 1), dtype=bool, numpy=
array([[[ True],
        [ True],
        [ True],
        [ True],
        [ True],
        [ True]],

       [[ True],
        [ True],
        [ True],
        [False],
        [False],
        [False]]])>

#This is self attention, thus Q and K are the same
my_mask2 = embedding_mask[:, None, :]
my_mask2

<tf.Tensor: shape=(2, 1, 6), dtype=bool, numpy=
array([[[ True,  True,  True,  True,  True,  True]],

       [[ True,  True,  True, False, False, False]]])>

#According to the `attention_mask` argument of `tf.keras.layers.MultiHeadAttention` (https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention#call-arguments_1), this is the attention_mask which is a boolean mask of shape (B, T, S)
my_attention_mask = my_mask1 & my_mask2
my_attention_mask  #[batch_size, input_seq_len, input_seq_len]

<tf.Tensor: shape=(2, 6, 6), dtype=bool, numpy=
array([[[ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True]],

       [[ True,  True,  True, False, False, False],
        [ True,  True,  True, False, False, False],
        [ True,  True,  True, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, False, False]]])>
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73835386

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档