首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >掩蔽在语言理解转换器的scaled_dot_product_attention中是如何工作的?

掩蔽在语言理解转换器的scaled_dot_product_attention中是如何工作的?
EN

Stack Overflow用户
提问于 2021-02-19 22:57:10
回答 1查看 203关注 0票数 0

我一直在关注Tensorflow关于Transformers的语言理解教程。(here)。但是,我对函数scaled_dot_product_attention中使用的掩码有点困惑。我知道掩码是用来做什么的,但我确实知道它们在这个函数中是如何工作的。

当我学习本教程时,我了解到掩码将有一个矩阵,指示哪些元素是填充元素(掩码矩阵中的值为1),哪些不是(掩码矩阵中的值为0)。例如:

代码语言:javascript
复制
[0 , 0 , 1 
 1 , 0 , 0 
 0 , 1 , 0 ]

但是,我可以看到函数scaled_dot_product_attention试图用一个非常大(或很小)的数字-1e9 (负10亿)来更新填充元素。这可以在提到的函数的以下行中看到:

代码语言:javascript
复制
      if mask is not None:
    scaled_attention_logits += (mask * -1e9)

为什么要这样做?这在数学上是如何导致忽略这些值的呢?以下是本教程中显示的实现:

代码语言:javascript
复制
   def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type(padding or look ahead) 
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable 
          to (..., seq_len_q, seq_len_k). Defaults to None.

  Returns:
    output, attention_weights
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)  

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-04-19 07:57:29

好的,所以值-1e9类似于负无穷大。因此,softmax函数将对这些元素产生0的概率,并且在计算关注值时将被忽略。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66279882

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档