当我们将输入作为节点特征(x)和边缘索引(edge_index)传递给pytorch_geometric层(例如GATConv)时,我担心该层是否能区分给定节点元素属于哪个批次样本。
X遵循节点的形状数量,特征大小和edge_index遵循形状2,边的数量。然而,这两个没有给定的信息来知道哪些批次大小为32的输入图在x中具有给定的节点特征。
有人能澄清这一点吗?
发布于 2021-03-02 11:18:58
PyTorch-Geometric将批处理中的所有图形视为单个巨大的图形,各个图形彼此断开连接。节点索引对应于这个大图中的节点。这意味着在x或edge_index中不需要批处理维度。
https://stackoverflow.com/questions/66423622
复制相似问题