使用Keras回调函数无法保存模型

我正在处理一个图像识别问题。我的模型需要训练200个周期。我希望在每个周期结束时,如果模型的验证准确率是迄今为止最好的,就保存该模型。这是我的代码,

from keras.models import Sequentialfrom keras.models import Modelfrom keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau, TensorBoardfrom keras import optimizers, losses, activations, modelsfrom keras.layers import Convolution2D, Dense, Input, Flatten, Dropout, MaxPooling2D, BatchNormalization, GlobalAveragePooling2D, Concatenatefrom keras import applicationsfrom keras import backend as Kfrom keras import callbacksfrom keras.preprocessing.image import ImageDataGeneratorROWS,COLS = 669,1026input_shape = (ROWS, COLS, 3)base_model = applications.VGG19(weights='imagenet',                                 include_top=False,                                 input_shape=(ROWS, COLS,3))l = 0for layer in base_model.layers:    layer.trainable = False    l += 1c = 0for layer in base_model.layers:    c += 1    if c > l-5:        layer.trainable = True for layer in base_model.layers:    print(layer,layer.trainable)base_model.summary()add_model = Sequential()add_model.add(base_model)add_model.add(GlobalAveragePooling2D())add_model.add(Dense(514, activation='relu'))add_model.add(Dense(128, activation='relu'))add_model.add(Dense(64, activation='relu'))add_model.add(Dropout(0.5))add_model.add(Dense(8, activation='relu'))add_model.add(Dropout(0.5))add_model.add(Dense(1, activation='sigmoid'))model = add_model# model.compile(loss='binary_crossentropy', #               optimizer=optimizers.SGD(lr=1e-, #                                        momentum=0.9),#               metrics=['accuracy'])model.compile(loss='binary_crossentropy',               optimizer='adam',              metrics=['accuracy'])model.summary()train_data_dir = '/home/spectrograms/train'validation_data_dir = '/home/spectrograms/test'nb_train_samples = 791nb_validation_samples = 198epochs = 200batch_size = 3if K.image_data_format() == 'channels_first':    input_shape = (3, ROWS, COLS)else:    input_shape = (ROWS, COLS,3)# this is the augmentation configuration we will use for trainingtrain_datagen = ImageDataGenerator(    rescale=1. / 255,    shear_range=0,    zoom_range=0,    horizontal_flip=False)test_datagen = ImageDataGenerator(rescale=1. / 255)train_generator = train_datagen.flow_from_directory(    train_data_dir,    target_size=(ROWS, COLS),    batch_size=batch_size,    class_mode='binary')validation_generator = test_datagen.flow_from_directory(    validation_data_dir,    target_size=(ROWS, COLS),    batch_size=batch_size,    class_mode='binary')checkpoint_filepath = '/home/CNN/saved_model/checkpoints/checkpoint-{epoch:02d}-{val_loss:.2f}.h5'model_checkpoint_callback = callbacks.ModelCheckpoint(    filepath=checkpoint_filepath,    save_weights_only=False,    monitor='val_accuracy',    mode='max',    save_best_only=True)model.fit_generator(    train_generator,    steps_per_epoch=nb_train_samples // batch_size,    epochs=epochs,    callbacks = [model_checkpoint_callback],    validation_data=validation_generator,    validation_steps=nb_validation_samples // batch_size)

但是我收到了错误OSError: Unable to create file (unable to open file: name = '/home/CNN/saved_model/checkpoints/checkpoint-01-0.69.h5', errno = 2, error message = 'No such file or directory', flags = 13, o_flags = 242)

但实际上该文件目录是存在的。我不明白这里的问题出在哪里。


回答:

在您的代码中添加以下内容

save_dir = '/home/CNN/saved_model/checkpoints/'if not os.path.exists(save_dir):   os.makedirs(save_dir)checkpoint_filepath = os.path.join(save_dir, "checkpoint-{epoch:02d}-{val_loss:.2f}.h5")model_checkpoint_callback = callbacks.ModelCheckpoint(filepath=checkpoint_filepath,save_weights_only=False,monitor='val_accuracy',mode='max',save_best_only=True) 

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注