首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何创建图神经网络数据集?(pytorch几何图形)

如何创建图神经网络数据集?(pytorch几何图形)
EN

Stack Overflow用户
提问于 2021-03-25 04:06:14
回答 3查看 691关注 0票数 4

如何将我自己的数据集转换为pytorch几何图形神经网络的可用数据集?

所有教程都使用已转换为可由pytorch使用的现有数据集。例如,如果我有自己的点云数据集,我如何使用它来训练图神经网络的分类?我自己的用于分类的图像数据集呢?

EN

回答 3

Stack Overflow用户

发布于 2021-10-28 23:45:09

就像文档中提到的那样。pytorch-geometric

我真的需要使用这些数据集接口吗?不是的!就像在常规PyTorch中一样,您不必使用数据集,例如,当您想要动态创建合成数据而不将其显式保存到磁盘时。在这种情况下,只需传递一个包含torch_geometric.data.Data对象的常规python列表并将它们传递给torch_geometric.loader.DataLoader

代码语言:javascript
复制
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
票数 2
EN

Stack Overflow用户

发布于 2021-03-25 04:22:38

您需要如何转换数据取决于您的模型期望的格式。

图神经网络通常期望(的子集):

  • 节点features
  • edges
  • edge attributes
  • node以

为目标

这取决于问题所在。您可以使用以下Data对象在PyTorch Geometric中使用这些值的张量创建对象(并根据需要扩展属性):

代码语言:javascript
复制
data = Data(x=x, edge_index=edge_index, y=y)
data.train_idx = torch.tensor([...], dtype=torch.long)
data.test_mask = torch.tensor([...], dtype=torch.bool)
票数 1
EN

Stack Overflow用户

发布于 2021-09-25 20:05:44

代码语言:javascript
复制
from torch_geometric.data import Dataset, Data
class MyCustomDataset(Dataset):
    def __init__():
        self.filename = .. # List of raw files, in your case point cloud
        super(MyCustomDataset, self).__init()

    @property
    def raw_file_names(self):
        return self.filename
    
    @property
    def processed_file_names(self):
        """ return list of files should be in processed dir, if found - skip processing."""
        processed_filename = []
        return processed_filename
    def download(self):
        pass

    def process(self):
        for file in self.raw_paths:
            self._process_one_step(file)

    def _process_one_step(self, path):
        out_path = (self.processed_dir, "some_unique_filename.pt")
        # read your point cloud here, 
        # convert point cloud to Data object
        data = Data(x=node_features,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=label #you can add more arguments as you like
                    )
        torch.save(data, out_path)
        return

    def __len__(self):
        return len(self.processed_file_names)

    def __getitem__(self, idx):
        data = torch.load(os.path.join(self.processed_dir, self.processed_file_names[idx]))
        return data

这将以正确的格式创建数据。然后,您可以使用torch_geometric.data.Dataloader创建数据加载器,然后训练您的网络。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66788555

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档