首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >DGL中以节点为输入、边为输出的图神经网络

DGL中以节点为输入、边为输出的图神经网络
EN

Stack Overflow用户
提问于 2020-06-26 06:33:31
回答 1查看 754关注 0票数 0

我想调整示例DGL GATLayer,以便网络可以学习边权重,而不是学习节点表示。也就是说,我想构建一个网络,它将一组节点特征作为输入并输出边。标签将是一组“真值边”,表示哪些节点来自共同的来源,这样我就可以学习以同样的方式聚类看不见的数据。

我使用以下DGL示例中的代码作为起点:

https://www.dgl.ai/blog/2019/02/17/gat.html

代码语言:javascript
复制
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    
    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e' : F.leaky_relu(a)}
    
    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z' : edges.src['z'], 'e' : edges.data['e']}
    
    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h' : h}
    
    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge
    
    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs))

class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        #   multiple head outputs are concatenated together. Also, only
        #   one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
    
    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

我曾希望我可以将其修改为简单地返回边而不是节点,例如通过替换行

return self.g.ndata.pop('h')

使用

return self.e.ndata.pop('e')

但看起来并不是这么简单。我设法让一些东西运行,但是损失到处都是,并且没有学习发生。

我对图网络是个新手,尽管不是一般意义上的深度学习。我正在尝试做的事情是合理的吗?在我对它的工作原理的理解中,我是否遗漏了一些至关重要的东西?我一直找不到任何易于理解的图网络的例子,其中边本身是学习目标,所以我现在有点困惑。我很感谢任何人能给予的帮助!

EN

回答 1

Stack Overflow用户

发布于 2020-08-17 19:45:54

我不能完全确定,因为它取决于你的输入,但是self.g很可能是一个DGL图,因此在他们访问ndata的代码中,ndata代表节点数据,如果你想访问图的边数据,你可以访问edata。因此,您应该编写返回self.g.edata...即使我不确定你正在尝试访问的边的哪些属性会改变pop(无论你试图访问什么)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62585304

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档