首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow:如何提取用于绘图的attention_scores?

Tensorflow:如何提取用于绘图的attention_scores?
EN

Stack Overflow用户
提问于 2022-01-04 02:02:58
回答 1查看 690关注 0票数 3

如果您在Keras中有一个MultiHeadAttention层,那么它可以返回如下所示的注意力分数:

代码语言:javascript
复制
    x, attention_scores = MultiHeadAttention(1, 10, 10)(x, return_attention_scores=True)

如何从网络图中提取注意力分数?我想给他们画个图。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-01-04 11:25:31

选项1:如果您想在培训期间绘制注意力分数,您可以创建一个Callback并将数据传递给它。它可以被触发,例如,在每一个时代之后。下面是一个例子,我使用了两个注意头,并在每一个时代之后绘制它们:

代码语言:javascript
复制
import tensorflow as tf
import seaborn as sb
import matplotlib.pyplot as plt

class CustomCallback(tf.keras.callbacks.Callback):
   def __init__(self, data):
      self.data = data
   def on_epoch_end(self, epoch, logs=None):
      test_targets, test_sources = self.data 
      _, attention_scores = attention_layer(test_targets[:1], test_sources[:1], return_attention_scores=True) # take one sample

      fig, axs = plt.subplots(ncols=3, gridspec_kw=dict(width_ratios=[5,5,0.2]))
      sb.heatmap(attention_scores[0, 0, :, :], annot=True, cbar=False, ax=axs[0])
      sb.heatmap(attention_scores[0, 1, :, :], annot=True, yticklabels=False, cbar=False, ax=axs[1])
      fig.colorbar(axs[1].collections[0], cax=axs[2])
      plt.show()

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.layers.Input(shape=[8, 16])
source = tf.keras.layers.Input(shape=[4, 16])
output_tensor, weights = layer(target, source,
                               return_attention_scores=True)
output = tf.keras.layers.Flatten()(output_tensor)
output = tf.keras.layers.Dense(1, activation='sigmoid')(output)

model = tf.keras.Model([target, source], output)
model.compile(optimizer = 'adam', loss = tf.keras.losses.BinaryCrossentropy())

attention_layer = model.layers[2]
samples = 5
train_targets = tf.random.normal((samples, 8, 16))
train_sources = tf.random.normal((samples, 4, 16))
test_targets = tf.random.normal((samples, 8, 16))
test_sources = tf.random.normal((samples, 4, 16))
y = tf.random.uniform((samples,), maxval=2, dtype=tf.int32)

model.fit([train_targets, train_sources], y, batch_size=2, epochs=2, callbacks=[CustomCallback([test_targets, test_sources])])
代码语言:javascript
复制
Epoch 1/2
1/3 [=========>....................] - ETA: 2s - loss: 0.7142

代码语言:javascript
复制
3/3 [==============================] - 3s 649ms/step - loss: 0.6992
Epoch 2/2
1/3 [=========>....................] - ETA: 0s - loss: 0.7265

代码语言:javascript
复制
3/3 [==============================] - 1s 650ms/step - loss: 0.6863
<keras.callbacks.History at 0x7fcc839dc590>

选项2:如果您只想在训练后绘制注意力分数,则只需将一些数据传递到模型的注意层并绘制分数:

代码语言:javascript
复制
import tensorflow as tf
import seaborn as sb
import matplotlib.pyplot as plt

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.layers.Input(shape=[8, 16])
source = tf.keras.layers.Input(shape=[4, 16])
output_tensor, weights = layer(target, source,
                               return_attention_scores=True)
output = tf.keras.layers.Flatten()(output_tensor)
output = tf.keras.layers.Dense(1, activation='sigmoid')(output)

model = tf.keras.Model([target, source], output)
model.compile(optimizer = 'adam', loss = tf.keras.losses.BinaryCrossentropy())

samples = 5
train_targets = tf.random.normal((samples, 8, 16))
train_sources = tf.random.normal((samples, 4, 16))
test_targets = tf.random.normal((samples, 8, 16))
test_sources = tf.random.normal((samples, 4, 16))
y = tf.random.uniform((samples,), maxval=2, dtype=tf.int32)

model.fit([train_targets, train_sources], y, batch_size=2, epochs=2)

attention_layer = model.layers[2]

_, attention_scores = attention_layer(test_targets[:1], test_sources[:1], return_attention_scores=True) # take one sample
fig, axs = plt.subplots(ncols=3, gridspec_kw=dict(width_ratios=[5,5,0.2]))
sb.heatmap(attention_scores[0, 0, :, :], annot=True, cbar=False, ax=axs[0])
sb.heatmap(attention_scores[0, 1, :, :], annot=True, yticklabels=False, cbar=False, ax=axs[1])
fig.colorbar(axs[1].collections[0], cax=axs[2])
plt.show()
代码语言:javascript
复制
Epoch 1/2
3/3 [==============================] - 1s 7ms/step - loss: 0.6727
Epoch 2/2
3/3 [==============================] - 0s 6ms/step - loss: 0.6503

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70573362

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档