在PyTorch几何学关于创建消息传递网络的教程中,他们在解释类功能的开头有这样一段话:
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
: 定义了使用的聚合方案("add", "mean" 或 "max"
)以及消息传递的流动方向("source_to_target"
或"target_to_source"
)。此外,node_dim
属性指示沿哪个轴进行传播。
我不理解这个node_dim
指的是什么,以及为什么它是-2。我查看了MessagePassing
类的文档,那里提到它是进行传播的轴——这仍然没有真正澄清我们在这里做什么,以及为什么默认值是-2(可能是这样在节点级别上传播信息)。请问有人可以给我解释一下吗?
回答:
参考这里和这里后,我认为与之相关的是‘message’函数的输出。
在大多数情况下,输出的形状是[edge_num, emb_out]
,如果我们将node_dim
设置为-2,就意味着我们将使用目标节点的索引沿edge_num
进行聚合。
这正是从源节点聚合信息的过程。
聚合后的结果是[node_num, emb_out]
。