逐个样本训练Keras模型并绘制验证曲线

我看到了一些关于逐个样本训练Keras模型的帖子。从这些帖子中我了解到,Keras模型即使在逐个样本训练的基础上也会更新模型权重。现在的问题是,在这种情况下,我们如何提供验证分割,以及如何绘制模型的验证曲线,因为我们并不是将整个数据集一次性提供给模型,而是逐行提供数据集?

reduce_lr = tf.keras.callbacks.LearningRateScheduler(lambda x: 1e-3 * 0.90 ** x)model.compile(optimizer=keras.optimizers.Adam(clipnorm=1.0, clipvalue=0.5), loss="mse")for i in range(train_size):    one_row = train_data.iloc[i:i+1, :].to_numpy()    one_row = tf.convert_to_tensor(one_row .reshape(-1, one_row .shape[0], one_row.shape[1]), np.float32)    one_label=one_row    history= model.fit(one_row ,one_label, epochs=10, validation_split=0.1,verbose=2, callbacks=[reduce_lr])

回答:

Keras中,你有两种方式来进行验证:

  1. 第一种是指定一部分训练数据用于验证。这部分数据可以从数据集的末尾或随机抽取,具体取决于传递的Shuffle参数。可以通过指定validation_split来激活此选项:

介于0和1之间的浮点数。用于验证数据的训练数据的比例。模型将分离这部分训练数据,不在其上进行训练,并在每个epoch结束时评估此数据上的损失和任何模型指标。验证数据是从提供的x和y数据的最后几个样本中选取的,在打乱之前。

  1. 第二种选项是自己提供验证数据。你可以在最开始分割数据集,并取其中的一部分(通常为10%到20%)作为验证数据。这可以通过validation_data参数来指定:

用于在每个epoch结束时评估损失和任何模型指标的数据。模型不会在这数据上进行训练。因此,请注意使用validation_splitvalidation_data提供的数据的验证损失不受正则化层(如噪声和丢弃)的影响。validation_data将覆盖validation_splitvalidation_data可以是:

一个由Numpy数组或张量组成的元组(x_val, y_val)。一个由NumPy数组组成的元组(x_val, y_val, val_sample_weights)。一个tf.data.Dataset。一个返回(inputs, targets)或(inputs, targets, sample_weights)的Python生成器或keras.utils.Sequence。使用tf.distribute.experimental.ParameterServerStrategy时,validation_data尚不支持。

所以你可以选择上面的选项2,并在每次迭代中传递一个验证数据样本。然而,你需要定义如何补偿训练集和验证集之间样本数量差异的策略。例如,你可以将验证集的索引在到达验证集末尾时自动重置为0


关于如何绘制它们,这与通常的方式没有太大不同:

Model.fit()返回:

一个History对象。其History.history属性记录了在连续epoch中的训练损失值和指标值,以及验证损失值和验证指标值(如果适用)。

因此,如果你有一个epoch,每次迭代后从History对象中提取训练和验证损失/准确度/指标等,并将它们保存到相应的列表中。如果每次迭代有多个epoch,那么根据你的需求,取平均值或最后一个值等。

最后,在最后,使用例如经典的Matplotlib来绘制它们。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注