在Keras中进行交叉验证

我在Keras中实现了一个多层感知器,并使用scikit-learn来进行交叉验证。为此,我受到了在Keras中进行交叉验证问题中找到的代码的启发

from sklearn.cross_validation import StratifiedKFolddef load_data():    # 使用此函数加载数据def create model():    # 使用此函数创建模型def train_and_evaluate__model(model, data[train], labels[train], data[test], labels[test]):    # 在此处进行拟合和评估if __name__ == "__main__":    X, Y = load_model()    kFold = StratifiedKFold(n_splits=10)    for train, test in kFold.split(X, Y):        model = None        model = create_model()        train_evaluate(model, X[train], Y[train], X[test], Y[test])

在我的神经网络研究中,我了解到神经网络的知识表示存在于突触权重中,并且在网络追踪过程中,更新权重以减少网络错误率并提高其性能。(在我的情况下,我使用的是监督学习)

为了更好地训练和评估神经网络的性能,常用的方法是交叉验证,它会返回数据集的分区以用于模型的训练和评估。

我的疑问是…

在这个代码片段中:

for train, test in kFold.split(X, Y):    model = None    model = create_model()    train_evaluate(model, X[train], Y[train], X[test], Y[test])

我们为每个生成的分区定义、训练和评估一个新的神经网络吗?

如果我的目标是为整个数据集微调网络,为什么定义一个单一的神经网络并用生成的分区来训练它是不正确的?

也就是说,为什么这段代码是这样的?

for train, test in kFold.split(X, Y):    model = None    model = create_model()    train_evaluate(model, X[train], Y[train], X[test], Y[test])

而不是这样?

model = Nonemodel = create_model()for train, test in kFold.split(X, Y):    train_evaluate(model, X[train], Y[train], X[test], Y[test])

是我对代码工作原理的理解错误了吗?还是我的理论有误?


回答:

如果我的目标是为整个数据集微调网络

你所说的“微调”是什么意思,或者你进行交叉验证(CV)的确切目的是什么并不清楚;一般来说,CV服务于以下目的之一:

  • 模型选择(选择超参数的值)
  • 模型评估

由于你的代码中没有定义任何超参数选择的搜索网格,看起来你使用CV是为了获得模型的预期性能(错误率、准确率等)。

无论你使用CV的理由是什么,第一个代码片段是正确的;你的第二个代码片段

model = Nonemodel = create_model()for train, test in kFold.split(X, Y):    train_evaluate(model, X[train], Y[train], X[test], Y[test])

将在不同的分区上顺序地训练你的模型(即在分区#1上训练,然后继续在分区#2上训练等),这本质上只是在你的整个数据集上进行训练,肯定不是交叉验证…

不过,在CV之后通常只暗示(并且经常被初学者忽略)的最后一步是,在你对通过CV程序给出的所选超参数和/或模型性能感到满意后,你会返回并再次训练你的模型,这次使用全部可用的数据。

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注