加载保存的模型以恢复训练

我正在训练一个ResNet模型来分类汽车品牌。

我在每次训练周期保存了模型的权重。

作为测试,我在第三个周期停止了训练。

# checkpoint = ModelCheckpoint("best_model.hdf5", monitor='loss', verbose=1)checkpoint_path = "weights/cp-{epoch:04d}.ckpt"checkpoint_dir = os.path.dirname(checkpoint_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(   checkpoint_path, verbose=1,    # Save weights, every epoch.   save_freq='epoch')model.save_weights(checkpoint_path.format(epoch=0))history = model.fit_generator(    training_set,    validation_data = test_set,    epochs = 50,    steps_per_epoch = len(training_set),    validation_steps = len(test_set),    callbacks = [cp_callback])

然而,当我加载这些权重时,我不确定是否从最后保存的周期恢复,因为它显示的仍然是周期1/50。以下是我用来加载最后保存的模型的代码。

from keras.models import Sequential, load_model# load the modelnew_model = load_model('./weights/cp-0003.ckpt')# fit the modelhistory = new_model.fit_generator(    training_set,    validation_data = test_set,    epochs = 50,    steps_per_epoch = len(training_set),    validation_steps = len(test_set),    callbacks = [cp_callback])

看起来是这样的:图像显示运行保存的权重从周期1/50重新开始

请问有人可以帮忙吗?


回答:

你可以使用 fit_generatorinitial_epoch 参数。默认情况下,它被设置为0,但你可以将其设置为任何正数:

from keras.models import Sequential, load_modelimport tensorflow as tfcheckpoint_path = "weights/cp-{epoch:04d}.ckpt"checkpoint_dir = os.path.dirname(checkpoint_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(   checkpoint_path, verbose=1,    # Save weights, every epoch.   save_freq='epoch')model.save_weights(checkpoint_path.format(epoch=0))history = model.fit_generator(    training_set,    validation_data=test_set,    epochs=3,    steps_per_epoch=len(training_set),    validation_steps=len(test_set),    callbacks = [cp_callback])new_model = load_model('./weights/cp-0003.ckpt')# fit the modelhistory = new_model.fit_generator(    training_set,    validation_data=test_set,    epochs=50,    steps_per_epoch=len(training_set),    validation_steps=len(test_set),    callbacks=[cp_callback],    initial_epoch=3)

这将使你的模型额外训练50 – 3 = 47个周期。


关于你的代码,如果你使用Tensorflow 2.X,以下是一些备注:

  • fit_generator 自从 fit 支持生成器后已被废弃
  • 你应该将你的导入从 from keras.... 替换为 from tensorflow.keras...

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注