首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >多头注意度计算

多头注意度计算
EN

Stack Overflow用户
提问于 2022-12-04 13:49:54
回答 1查看 10关注 0票数 0

我创建了一个多头关注层的模型,

代码语言:javascript
复制
import torch
import torch.nn as nn
query = torch.randn(2, 4)
key = torch.randn(2, 4)
value = torch.randn(2, 4)
model = nn.MultiheadAttention(4, 1, bias=False)
代码语言:javascript
复制
model(query, key, value)

我试着匹配得到的注意力输出,

代码语言:javascript
复制
softmax_output = torch.softmax(((query@model.in_proj_weight[:4])@((key@model.in_proj_weight[4:8]).t()))/2, dim=1)
intermediate_output = softmax_output@(value@model.in_proj_weight[8:12])
final_output = intermediate_output@model.out_proj.weight

但是final_output与注意力输出不匹配。

EN

回答 1

Stack Overflow用户

发布于 2022-12-04 14:27:54

能够与输出相匹配,

代码语言:javascript
复制
q_w = query@model.in_proj_weight[:4].t()
k_w = key@model.in_proj_weight[4:8].t()
v_w = value@model.in_proj_weight[8:12].t()

softmax_output = torch.softmax((q_w@k_w.t())/2, dim=1)

attention = softmax_output@v_w

final_output = attention@model.out_proj.weight.t()

丢失了先前的转位

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74677218

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档