### 微调模型删除先前添加的层

我使用的是Keras 2.2.4版本。我训练了一个模型,希望每30个周期用新的数据内容(图像分类)进行微调。

每天我都会为模型添加更多图像到类中。每30个周期模型会被重新训练。我使用了两个条件,第一种情况是如果之前没有训练过的模型,第二种情况是当一个模型已经训练过后,我希望用新的内容/类别对其进行微调。

model_base = keras.applications.vgg19.VGG19(include_top=False, input_shape=(*IMG_SIZE, 3), weights='imagenet')
output = GlobalAveragePooling2D()(model_base.output)
# 如果我们要恢复一个预训练模型则加载它
if os.path.isfile(os.path.join(MODEL_PATH, 'weights.h5')):
    print('使用已有权重...')
    base_lr = 0.0001
    model = load_model(os.path.join(MODEL_PATH, 'weights.h5'))
    output = Dense(len(all_character_names), activation='softmax', name='d2')(output)
    model = Model(model_base.input, output)
    for layer in model_base.layers[:-2]:
        layer.trainable = False
else:
    base_lr = 0.001
    output = BatchNormalization()(output)
    output = Dropout(0.5)(output)
    output = Dense(2048, activation='relu', name='d1')(output)
    output = BatchNormalization()(output)
    output = Dropout(0.5)(output)
    output = Dense(len(all_character_names), activation='softmax', name='d2')(output)
    model = Model(model_base.input, output)
    for layer in model_base.layers[:-5]:
        layer.trainable = False
opt = optimizers.Adam(lr=base_lr, decay=base_lr / epochs)
model.compile(optimizer=opt,
            loss='categorical_crossentropy',
            metrics=['accuracy'])

第一次模型摘要:

..._________________________________________________________________
block5_conv4 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
global_average_pooling2d_1 ( (None, 512)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
d1 (Dense)                   (None, 2048)              1050624   
_________________________________________________________________
batch_normalization_2 (Batch (None, 2048)              8192      
_________________________________________________________________
dropout_2 (Dropout)          (None, 2048)              0         
_________________________________________________________________
d2 (Dense)                   (None, 19)                38931     
=================================================================
Total params: 21,124,179
Trainable params: 10,533,907
Non-trainable params: 10,590,272

第二次模型摘要:

..._________________________________________________________________
block5_conv4 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
global_average_pooling2d_1 ( (None, 512)               0         
_________________________________________________________________
d2 (Dense)                   (None, 19)                9747      
=================================================================
Total params: 20,034,131
Trainable params: 2,369,555
Non-trainable params: 17,664,576

问题:当一个模型存在并被加载以进行微调时,似乎失去了第一次添加的所有额外层(Dense 2048, Dropout等)。

我需要再次添加这些层吗?这似乎没有意义,因为它会丢失第一次训练的信息。

注意:我可能需要不设置base_lr,因为保存一个模型应该也保存它停止前的学习率状态,但这我稍后会检查。


回答:

请注意,一旦你加载了模型:

model = load_model(os.path.join(MODEL_PATH, 'weights.h5'))

你没有使用它。你只是再次覆盖它

model = Model(model_base.input, output)

其中output也被定义为对base_model的操作。在我看来,你只需要删除load_model之后的代码行。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注