在Keras中进行交叉验证

我在Keras中实现了一个多层感知器,并使用scikit-learn来进行交叉验证。为此,我受到了在Keras中进行交叉验证问题中找到的代码的启发

from sklearn.cross_validation import StratifiedKFolddef load_data():    # 使用此函数加载数据def create model():    # 使用此函数创建模型def train_and_evaluate__model(model, data[train], labels[train], data[test], labels[test]):    # 在此处进行拟合和评估if __name__ == "__main__":    X, Y = load_model()    kFold = StratifiedKFold(n_splits=10)    for train, test in kFold.split(X, Y):        model = None        model = create_model()        train_evaluate(model, X[train], Y[train], X[test], Y[test])

在我的神经网络研究中,我了解到神经网络的知识表示存在于突触权重中,并且在网络追踪过程中,更新权重以减少网络错误率并提高其性能。(在我的情况下,我使用的是监督学习)

为了更好地训练和评估神经网络的性能,常用的方法是交叉验证,它会返回数据集的分区以用于模型的训练和评估。

我的疑问是…

在这个代码片段中:

for train, test in kFold.split(X, Y):    model = None    model = create_model()    train_evaluate(model, X[train], Y[train], X[test], Y[test])

我们为每个生成的分区定义、训练和评估一个新的神经网络吗?

如果我的目标是为整个数据集微调网络,为什么定义一个单一的神经网络并用生成的分区来训练它是不正确的?

也就是说,为什么这段代码是这样的?

for train, test in kFold.split(X, Y):    model = None    model = create_model()    train_evaluate(model, X[train], Y[train], X[test], Y[test])

而不是这样?

model = Nonemodel = create_model()for train, test in kFold.split(X, Y):    train_evaluate(model, X[train], Y[train], X[test], Y[test])

是我对代码工作原理的理解错误了吗?还是我的理论有误?


回答:

如果我的目标是为整个数据集微调网络

你所说的“微调”是什么意思,或者你进行交叉验证(CV)的确切目的是什么并不清楚;一般来说,CV服务于以下目的之一:

  • 模型选择(选择超参数的值)
  • 模型评估

由于你的代码中没有定义任何超参数选择的搜索网格,看起来你使用CV是为了获得模型的预期性能(错误率、准确率等)。

无论你使用CV的理由是什么,第一个代码片段是正确的;你的第二个代码片段

model = Nonemodel = create_model()for train, test in kFold.split(X, Y):    train_evaluate(model, X[train], Y[train], X[test], Y[test])

将在不同的分区上顺序地训练你的模型(即在分区#1上训练,然后继续在分区#2上训练等),这本质上只是在你的整个数据集上进行训练,肯定不是交叉验证…

不过,在CV之后通常只暗示(并且经常被初学者忽略)的最后一步是,在你对通过CV程序给出的所选超参数和/或模型性能感到满意后,你会返回并再次训练你的模型,这次使用全部可用的数据。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注