使用Stellargraph进行数据 shuffling 时出现问题

当我运行Stellargraph关于使用DGCNN进行图分类演示时,得到的结果与演示中相同。

然而,当我首先使用以下代码打乱数据时,测试结果却有所不同:

shuffler = list(zip(graphs, graph_labels))random.shuffle(shuffler)graphs, graph_labels = zip(*shuffler)

模型完全没有学习(准确率约为50% – 与数据分布一致)。

有谁知道为什么会这样吗?可能是我的打乱方式不对?还是说数据本来就不应该打乱(如果是这样,为什么呢?这似乎没有道理)?或者这是Stellargraph实现中的一个bug?


回答:

我找到了问题所在。问题与打乱算法或Stellargraph的实现无关。问题出在演示中的以下几行代码:

train_gen = gen.flow(    list(train_graphs.index - 1),    targets=train_graphs.values,    batch_size=50,    symmetric_normalization=False,)test_gen = gen.flow(    list(test_graphs.index - 1),    targets=test_graphs.values,    batch_size=1,    symmetric_normalization=False,)

问题具体是由train_graphs.index - 1test_graphs.index - 1引起的。索引已经在0n的范围内,因此从中减去1会导致图数据“向后移动”一个位置,使得每个数据点获得了另一个数据点的标签。

要解决这个问题,只需将它们改为train_graphs.indextest_graphs.index,去掉末尾的-1即可。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注