首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >不带节点属性的DGL分类图

不带节点属性的DGL分类图
EN

Stack Overflow用户
提问于 2022-02-07 15:19:29
回答 2查看 564关注 0票数 2

我按照本指南从我自己的数据中创建用于图形分类的数据集:data.html

在那里,它们不创建任何节点的特性,因为如果要预测图类,就没有必要。在我的例子中,它是相同的,我不想使用任何节点特性(还)来进行分类。

为了训练GNN,我遵循本教程:classification.html#sphx-glr-tutorials-blitz-5-graph-classification-py

两者都来自官方文档,但它们似乎不兼容,因为当我试图将它们一起使用时,我收到了以下错误:

代码语言:javascript
复制
KeyError                                  Traceback (most recent call last) <ipython-input-39-8a94f1fa250d> in <module>
      4 for epoch in range(20):
      5     for batched_graph, labels in train_dataloader:
----> 6         pred = model(batched_graph, batched_graph.ndata['attr'].float())
      7         loss = F.cross_entropy(pred, labels)
      8         optimizer.zero_grad()

~/anaconda3/lib/python3.8/site-packages/dgl/view.py in
__getitem__(self, key)
     64             return ret
     65         else:
---> 66             return self._graph._get_n_repr(self._ntid, self._nodes)[key]
     67 
     68     def __setitem__(self, key, val):

~/anaconda3/lib/python3.8/site-packages/dgl/frame.py in
__getitem__(self, name)
    391             Column data.
    392         """
--> 393         return self._columns[name].data
    394 
    395     def __setitem__(self, name, data):

KeyError: 'attr'

而且,我找不到另一个例子,不使用节点的特性来使用DGl来训练GNN。有可能吗?我必须创建假属性吗?

谢谢!

EN

回答 2

Stack Overflow用户

发布于 2022-02-08 16:46:03

DGL模型总是需要至少有一个特性。因此,我使用分类器中的度特征来解决这个问题:

代码语言:javascript
复制
class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # Use node degree as the initial node feature. For undirected graphs, the in-degree
        # is the same as the out_degree.
        h = g.in_degrees().view(-1, 1).float()
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
票数 1
EN

Stack Overflow用户

发布于 2022-06-08 15:53:02

当数据集中的节点特性是任务或它与dataset模块中的定义不同时,我得到了此错误。

代码语言:javascript
复制
epoch_losses = []
for epoch in range(200):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss) # this 

根据下面的教程,我假设您定义了图形节点特性g.ndata['h']而不是batched_graph.ndata['attr'] --特别是属性的命名。

模式训练损失曲线

你可能会觉得这很有帮助

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71020966

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档