Pytorch geometric: 如何解释下面的代码片段中的输入?

我正在阅读PyTorch Geometric的文档,地址是https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html

在该页面上,有一个代码片段:

import torchfrom torch_geometric.data import Dataedge_index = torch.tensor([[0, 1, 1, 2],                           [1, 0, 2, 1]], dtype=torch.long)x = torch.tensor([[-1], [0], [1]], dtype=torch.float)data = Data(x=x, edge_index=edge_index)

上述代码片段最后一行的输出是:

Data(edge_index=[2, 4], x=[3, 1])

edge_index 中的2和4是怎么来的?如果我理解正确的话,这里定义了四个边,索引从0开始。这个假设是否错误?另外,x =[3, 1] 是什么意思?

Data 是一个类,所以我不会期望它返回任何东西。类的定义在这里:https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html。我读了文档,x 应该是节点特征矩阵,而edge_index 应该是图的连通性。但我无法理解我在Jupyter Notebook中验证的控制台输出。


回答:

好的,我想我已经理解了输出Data(edge_index=[2, 4], x=[3, 1]) 的含义。这里的[2,4] 是edge_index 的维度,而[3,1] 是x 的维度。如果我错了,请任何人纠正我。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注