首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch:如何为图注意力层实现注意力

PyTorch:如何为图注意力层实现注意力
EN

Stack Overflow用户
提问于 2018-03-19 16:17:41
回答 1查看 1K关注 0票数 0

我已经实现了注意力(等式)。1)的https://arxiv.org/pdf/1710.10903.pdf,但它显然不是内存效率高,并且只能在我的图形处理器上运行一个模型(它需要7-10 of )。

目前,我有

代码语言:javascript
复制
class MyModule(nn.Module):

def __init__(self, in_features, out_features):
    super(MyModule, self).__init__()
    self.in_features = in_features
    self.out_features = out_features

    self.W = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(in_features, out_features).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)
    self.a = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(2*out_features, 1).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)

def forward(self, input):
    h = torch.mm(input, self.W)
    N = h.size()[0]

    a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
    e = F.elu(torch.matmul(a_input, self.a).squeeze(2))
    return e

我对计算所有e_ij术语的见解是

代码语言:javascript
复制
In [8]: import torch

在9中:将numpy导入为np

In 10: h= torch.LongTensor(np.array([1,1,2,2,3,3]))

In 11: N=3

In 12: h.repeat(1,N).view(N * N,-1) Out12:

代码语言:javascript
复制
1     1
1     1
1     1
2     2
2     2
2     2
3     3
3     3
3     3

9x2大小的torch.LongTensor

在13中: h.repeat(N,1) Out13:

代码语言:javascript
复制
1     1
2     2
3     3
1     1
2     2
3     3
1     1
2     2
3     3

9x2大小的torch.LongTensor

并且最后将hs和馈送矩阵a连接起来。

有没有一种对内存更友好的方式呢?

EN

回答 1

Stack Overflow用户

发布于 2018-07-15 21:19:53

也许你可以使用稀疏张量来存储adj_mat

代码语言:javascript
复制
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row,
                                          sparse_mx.col))).long()
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)
票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49358396

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档