如何重置Keras指标?

为了进行一些参数调优,我喜欢对Keras的训练函数进行循环操作。然而,我发现当使用tensorflow.keras.metrics.AUC()作为指标时,每次训练循环都会在auc指标名称上增加一个整数(例如,auc_1, auc_2, …)。因此,实际上即使退出训练函数,Keras指标仍然被存储。

首先,这导致回调函数无法再识别该指标,同时也让我怀疑是否还有其他东西被存储,比如模型权重。

我如何重置这些指标?还有其他由Keras存储的东西需要我重置,以获得训练的彻底重启吗?

以下是一个最小工作示例:

编辑:这个示例似乎只适用于TensorFlow 2.2

输出结果为:

 **** Loop 0 **** 2020-06-16 14:37:46.621264: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f991e541f10 initialized for platform Host (this does not guarantee that XLA will be used). Devices:2020-06-16 14:37:46.621296: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default VersionEpoch 1/210/10 [==============================] - 0s 44ms/step - loss: 0.1295 - auc: 0.0000e+00 - val_loss: 0.0310 - val_auc: 0.0000e+00 - lr: 0.0010Epoch 2/210/10 [==============================] - 0s 10ms/step - loss: 0.0262 - auc: 0.0000e+00 - val_loss: 0.0223 - val_auc: 0.0000e+00 - lr: 0.0010 **** Loop 1 **** Epoch 1/210/10 [==============================] - ETA: 0s - loss: 0.4751 - auc_1: 0.0000e+00WARNING:tensorflow:Reduce LR on plateau conditioned on metric `val_auc` which is not available. Available metrics are: loss,auc_1,val_loss,val_auc_1,lrWARNING:tensorflow:Early stopping conditioned on metric `val_auc` which is not available. Available metrics are: loss,auc_1,val_loss,val_auc_1,lr10/10 [==============================] - 0s 36ms/step - loss: 0.4751 - auc_1: 0.0000e+00 - val_loss: 0.3137 - val_auc_1: 0.0000e+00 - lr: 0.0010Epoch 2/210/10 [==============================] - ETA: 0s - loss: 0.2617 - auc_1: 0.0000e+00WARNING:tensorflow:Reduce LR on plateau conditioned on metric `val_auc` which is not available. Available metrics are: loss,auc_1,val_loss,val_auc_1,lrWARNING:tensorflow:Early stopping conditioned on metric `val_auc` which is not available. Available metrics are: loss,auc_1,val_loss,val_auc_1,lr10/10 [==============================] - 0s 10ms/step - loss: 0.2617 - auc_1: 0.0000e+00 - val_loss: 0.2137 - val_auc_1: 0.0000e+00 - lr: 0.0010 **** Loop 2 **** Epoch 1/210/10 [==============================] - ETA: 0s - loss: 0.1948 - auc_2: 0.0000e+00WARNING:tensorflow:Reduce LR on plateau conditioned on metric `val_auc` which is not available. Available metrics are: loss,auc_2,val_loss,val_auc_2,lrWARNING:tensorflow:Early stopping conditioned on metric `val_auc` which is not available. Available metrics are: loss,auc_2,val_loss,val_auc_2,lr10/10 [==============================] - 0s 34ms/step - loss: 0.1948 - auc_2: 0.0000e+00 - val_loss: 0.0517 - val_auc_2: 0.0000e+00 - lr: 0.0010Epoch 2/210/10 [==============================] - ETA: 0s - loss: 0.0445 - auc_2: 0.0000e+00WARNING:tensorflow:Reduce LR on plateau conditioned on metric `val_auc` which is not available. Available metrics are: loss,auc_2,val_loss,val_auc_2,lrWARNING:tensorflow:Early stopping conditioned on metric `val_auc` which is not available. Available metrics are: loss,auc_2,val_loss,val_auc_2,lr10/10 [==============================] - 0s 10ms/step - loss: 0.0445 - auc_2: 0.0000e+00 - val_loss: 0.0389 - val_auc_2: 0.0000e+00 - lr: 0.0010

回答:

您的可重现示例在我这里有几个地方失败了,所以我只做了一些小的改动(我使用的是TF 2.1)。在运行后,我通过指定metrics=[AUC(name='auc')]成功消除了额外的指标名称。以下是完整的(已修复的)可重现示例:

Train for 10 steps, validate for 10 stepsEpoch 1/2 1/10 [==>...........................] - ETA: 6s - loss: 0.3426 - auc: 0.4530 7/10 [====================>.........] - ETA: 0s - loss: 0.3318 - auc: 0.489510/10 [==============================] - 1s 117ms/step - loss: 0.3301 -                                          auc: 0.4893 - val_loss: 0.3222 - val_auc: 0.5085

这是因为每次循环,您通过metrics=[AUC()]创建了一个没有指定名称的新指标。在循环的第一次迭代中,TF自动在名称空间中创建了一个名为auc的变量,但在循环的第二次迭代中,名称'auc'已经被占用,因此由于您没有指定名称,TF将其命名为auc_1。但是,您的回调函数设置为基于auc,而这是该模型没有的指标(它是前一个循环模型的指标)。因此,您可以使用name='auc'覆盖之前的指标名称,或者像这样在循环之外定义它:

不用担心keras会重置指标。所有这些都在fit()方法中处理。如果您想要更多灵活性和/或自己操作,我建议使用自定义训练循环,并自己重置它:

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注