model.fit()在每个epoch后是否会重置指标?如何手动重置指标?

据我所知,model.fit(epochs=NUM_EPOCHS)不会在每个epoch后重置指标。我的指标和model.fit()的代码如下(简化版):

import tensorflow as tffrom tensorflow.keras import applicationsNUM_CLASSES = 4INPUT_SHAPE = (256, 256, 3)MODELS = {    'DenseNet121': applications.DenseNet121,    'DenseNet169': applications.DenseNet169}REDUCE_LR_PATIENCE = 2REDUCE_LR_FACTOR = 0.7EARLY_STOPPING_PATIENCE = 4for modelName, model in MODELS.items():    loadedModel = model(include_top=False, weights='imagenet',                        pooling='avg', input_shape=INPUT_SHAPE)    sequentialModel = tf.keras.models.Sequential()    sequentialModel.add(loadedModel)    sequentialModel.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))    aucCurve = tf.keras.metrics.AUC(curve = 'ROC', multi_label = True)    categoricalAccuracy = tf.keras.metrics.CategoricalAccuracy()    F1Score  = tfa.metrics.F1Score(num_classes = NUM_CLASSES, average = 'macro', threshold = None)    metrics = [aucCurve, categoricalAccuracy, F1Score]    sequentialModel.compile(metrics=metrics)    callbacks = [    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=REDUCE_LR_PATIENCE, verbose=1, factor=REDUCE_LR_FACTOR),    tf.keras.callbacks.EarlyStopping(monitor='val_loss', verbose=1, patience=EARLY_STOPPING_PATIENCE),    tf.keras.callbacks.ModelCheckpoint(filepath=modelName + '_epoch-{epoch:02d}.h5', monitor='val_loss', save_best_only=False, verbose=1),    tf.keras.callbacks.CSVLogger(modelName + '_training.csv')]    sequentialModel.fit(epochs=NUM_EPOCHS)

或许我可以通过在NUM_EPOCHS范围内进行for循环,并在循环中初始化指标来重置指标,但我并不确定这是否是一个好的解决方案。此外,我有ModelCheckpoint和CSVLogger回调,它们需要从model.fit()获取epoch编号,因此如果我使用for循环,这实际上是行不通的。

您对如何在每个epoch后重置指标有什么建议吗?在NUM_EPOCHS范围内进行for循环是这里唯一的解决方案吗?谢谢您。


回答:

不,指标是按每个epoch计算的。它们不是在epoch之间进行平均,而是每个epoch内的批次进行平均。你会看到指标在每个epoch后不断改善,因为你的模型正在被训练。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注