我正在阅读PyTorch Geometric的文档,地址是https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html
在该页面上,有一个代码片段:
import torchfrom torch_geometric.data import Dataedge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)x = torch.tensor([[-1], [0], [1]], dtype=torch.float)data = Data(x=x, edge_index=edge_index)
上述代码片段最后一行的输出是:
Data(edge_index=[2, 4], x=[3, 1])
edge_index
中的2和4是怎么来的?如果我理解正确的话,这里定义了四个边,索引从0开始。这个假设是否错误?另外,x =[3, 1]
是什么意思?
Data
是一个类,所以我不会期望它返回任何东西。类的定义在这里:https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html。我读了文档,x
应该是节点特征矩阵,而edge_index
应该是图的连通性。但我无法理解我在Jupyter Notebook中验证的控制台输出。
回答:
好的,我想我已经理解了输出Data(edge_index=[2, 4], x=[3, 1])
的含义。这里的[2,4] 是edge_index
的维度,而[3,1] 是x
的维度。如果我错了,请任何人纠正我。