如何在Keras中查看提前停止的最佳轮次的损失?

我已经成功地在我的Keras模型中实现了提前停止,但是我不确定如何查看最佳轮次的损失。

es = EarlyStopping(monitor='val_out_soft_loss',             mode='min',            restore_best_weights=True,             verbose=2,             patience=10)model.fit(tr_x,          tr_y,          batch_size=batch_size,          epochs=epochs,          verbose=1,          callbacks=[es],          validation_data=(val_x, val_y))loss = model.history.history["val_out_soft_loss"][-1]return model, loss

我定义损失分数的方式意味着返回的分数来自最后一个轮次,而不是最佳轮次。

示例:

from sklearn.model_selection import train_test_split, KFoldlosses = []models = []for k in range(2):    kfold = KFold(5, random_state = 42 + k, shuffle = True)    for k_fold, (tr_inds, val_inds) in enumerate(kfold.split(train_y)):        print("-----------")        print("-----------")        model, loss = get_model(64, 100)        models.append(model)        print(k_fold, loss)        losses.append(loss)print("-------")print(losses)print(np.mean(losses))Epoch 23/10018536/18536 [==============================] - 7s 362us/step - loss: 0.0116 - out_soft_loss: 0.0112 - out_reg_loss: 0.0393 - val_loss: 0.0131 - val_out_soft_loss: 0.0127 - val_out_reg_loss: 0.0381Epoch 24/10018536/18536 [==============================] - 7s 356us/step - loss: 0.0116 - out_soft_loss: 0.0112 - out_reg_loss: 0.0388 - val_loss: 0.0132 - val_out_soft_loss: 0.0127 - val_out_reg_loss: 0.0403Restoring model weights from the end of the best epochEpoch 00024: early stopping0 0.012735568918287754

因此,在这个示例中,我想查看第14轮次(Epoch 00014)的损失(0.0124)。

我还有一个单独的问题:如何设置val_out_soft_loss分数的小数位数?


回答:

将Keras中的fit()调用赋值给一个变量,以便你可以跟踪各个轮次的度量标准。

history = model.fit(tr_x, ...

它会返回一个字典,像这样访问它:

loss_hist = history.history['loss']

然后使用min()获取最小损失,使用argmin()获取最佳轮次(基于零的索引)。

np.min(loss_hist)np.argmin(loss_hist)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注