我进行图分类,读取边列表文件,将其分为训练集和测试集,并使用networkx检查训练集是否为连通图:
如果我想使用10折交叉验证,并且使用类似于以下代码:
from sklearn.model_selection import KFoldKFold(n_splits=10, random_state=None, shuffle=False)
如何确保每个折叠的训练集都是连通图呢?
我该如何添加这样的条件呢?
回答:
我仍然不太理解你的目标,我会尽力回答
基本上,KFold返回数据集的分割索引,它的输入是一个类数组对象,这是一个“愚蠢”的函数,它只是按索引分割数组,没有进一步的逻辑…
你需要创建具有以下结构的数据集:
类数组对象,其中每个项目都是一个连通图(如果适用,还需要另一个数组作为其标签
y
)
一旦你有了这个对象,你就可以使用KFold()
来创建多个测试和训练集
为了创建一组连通图,你可以选择多种方法,例如:
- 移除X个随机边并检查图是否连通(重复N次)
- 如你在评论中提到的,以增量方式添加边和节点,按顺序添加边,并在每次添加后检查图是否连通
这些方法的问题是它们都非常耗时,并且可能不收敛,这高度依赖于它们的输入(随机状态和边的顺序)
我的建议是找到图的最小生成树(MST),这是基本的连通图,要创建它的变体,只需随机添加任何不在MST中的边/边即可