如何使用预训练模型进行双输入迁移学习

我打算使用一个预训练模型(之前使用ModelCheckpointsave_best_only参数保存)进行双输入迁移学习。我有以下内容:

pretrained_model = load_model('best_weight.h5')def combined_net():         u_model = pretrained_model    u_output = u_model.layers[-1].output        v_model = pretrained_model    v_output = v_model.layers[-1].output    concat = concatenate([u_output, v_output])    #hidden1 = Dense(64, activation=activation)(concat) #was 128    main_output = Dense(1, activation='sigmoid', name='main_output')(concat) # pretrained_model.get_layer("input_1").input    model = Model(inputs=[u_model.input, v_model.input], outputs=main_output)    opt = SGD(lr=0.001, nesterov=True)    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])    return model

当我尝试使用以下代码进行拟合时:

best_weights_file="weights_best_of_pretrained_dual.hdf5"checkpoint = ModelCheckpoint(best_weights_file, monitor='val_acc', verbose=1, save_best_only=True, mode='max')callbacks = [checkpoint]base_model = combined_net()print(base_model.summary)history = base_model.fit([x_train_u, x_train_v], y_train,                         batch_size=batch_size,                         epochs=epochs,                         callbacks=callbacks,                          verbose=1,                         validation_data=([x_test_u, x_test_v], y_test),                          shuffle=True)

我遇到了以下错误:

ValueError: The list of inputs passed to the model is redundant. All inputs should only appear once. Found: [<tf.Tensor 'input_1_5:0' shape=(None, None, None, 3) dtype=float32>, <tf.Tensor 'input_1_5:0' shape=(None, None, None, 3) dtype=float32>]

显然,model = Model(inputs=[u_model.input, v_model.input], outputs=main_output)这行代码似乎引起了错误。

我只想使用预训练模型(“best_weight.h5”)来构建双输入单输出的模型。两个输入与之前初始化的相同,并且concatenate层应该在每个由加载的模型构建的模型的最后一层之前进行连接。

我在网上尝试了几种方法,但无法正确设置模型。

希望有人能帮助我

编辑:

预训练模型如下所示:

def vgg_16():    b_model = VGG16(weights='imagenet', include_top=False)    x = b_model.output    x = GlobalAveragePooling2D()(x)    x = Dense(256, activation=activation)(x)    predictions = Dense(1, activation='sigmoid')(x)    model = Model(inputs=b_model.input, outputs=predictions)    for layer in model.layers[:15]:  #        layer.trainable = False    opt = SGD(lr=init_lr, nesterov=True)    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])    return modelmain_model = vgg_16()history = main_model.fit(X_train, y_train, batch_size=batch_size,           epochs=EPOCHS, validation_data=(X_test, y_test), verbose=1,           callbacks=[es, mc, l_r])

回答:

这是正确的方法。当我定义combined_net时,我定义了两个新的输入,这些输入以相同的方式馈送到pre_trained模型中

def vgg_16():        b_model = tf.keras.applications.VGG16(weights='imagenet', include_top=False)    x = b_model.output    x = GlobalAveragePooling2D()(x)    x = Dense(256, activation='relu')(x)    predictions = Dense(1, activation='sigmoid')(x)    model = Model(inputs=b_model.input, outputs=predictions)        for layer in model.layers[:15]:        layer.trainable = False            opt = SGD(lr=0.003, nesterov=True)    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])        return modelmain_model = vgg_16()# main_model.fit(...)pretrained_model = Model(main_model.input, main_model.layers[-2].output)def combined_net():         inp_u = Input((224,224,3)) # 与预训练模型的输入维度相同    inp_v = Input((224,224,3)) # 与预训练模型的输入维度相同        u_output = pretrained_model(inp_u)    v_output = pretrained_model(inp_v)    concat = concatenate([u_output, v_output])    main_output = Dense(1, activation='sigmoid', name='main_output')(concat)    model = Model(inputs=[inp_u, inp_v], outputs=main_output)    opt = SGD(lr=0.001, nesterov=True)    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])    return modelbase_model = combined_net()base_model.summary()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注