从预训练模型中移除层的-issue

我有以下代码,需要移除模型的一些层并进行预测。但目前我遇到了错误。

 from tensorflow.keras.applications.resnet50 import ResNet50 from tensorflow.keras.preprocessing import image from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions import numpy as np from keras.models import Model from tensorflow.python.keras.optimizers import SGD base_model = ResNet50(include_top=False, weights='imagenet') model= Model(inputs=base_model.input, outputs=base_model .layers[-2].output) #model = Model(inputs=base_model.input, outputs=predictions) #Compiling the model model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics =  ['accuracy']) img_path = 'elephant.jpg' img = image.load_img(img_path, target_size=(224, 224)) x = image.img_to_array(img) x = np.expand_dims(x, axis=0) x = preprocess_input(x) preds = model.predict(x) #decode the results into a list of tuples (class, description, probability) #(one such list for each sample in the batch) print('Predicted:', decode_predictions(preds, top=3)[0])

错误

File "C:/Users/learn/remove_layer.py", line 9, in <module>model= Model(inputs=base_model.input, outputs=base_model .layers[-2].output)AttributeError: 'Tensor' object has no attribute '_keras_shape'

由于我在Keras方面的初学者知识,我理解这是形状问题。因为这是ResNet模型,如果我从一个合并层移除一层到另一个合并层,因为合并层没有维度问题,我该如何完成这个操作?


回答:

你实际上需要可视化你所做的操作,所以让我们对ResNet50模型的最后几层做一个小总结:

base_model.summary()conv5_block3_2_relu (Activation (None, None, None, 5 0           conv5_block3_2_bn[0][0]          __________________________________________________________________________________________________conv5_block3_3_conv (Conv2D)    (None, None, None, 2 1050624     conv5_block3_2_relu[0][0]        __________________________________________________________________________________________________conv5_block3_3_bn (BatchNormali (None, None, None, 2 8192        conv5_block3_3_conv[0][0]        __________________________________________________________________________________________________conv5_block3_add (Add)          (None, None, None, 2 0           conv5_block2_out[0][0]                                                                            conv5_block3_3_bn[0][0]          __________________________________________________________________________________________________conv5_block3_out (Activation)   (None, None, None, 2 0           conv5_block3_add[0][0]           ==================================================================================================Total params: 23,587,712Trainable params: 23,534,592Non-trainable params: 53,120_____________________________

现在是移除最后一层后的模型

model.summary()conv5_block3_2_relu (Activation (None, None, None, 5 0           conv5_block3_2_bn[0][0]          __________________________________________________________________________________________________conv5_block3_3_conv (Conv2D)    (None, None, None, 2 1050624     conv5_block3_2_relu[0][0]        __________________________________________________________________________________________________conv5_block3_3_bn (BatchNormali (None, None, None, 2 8192        conv5_block3_3_conv[0][0]        __________________________________________________________________________________________________conv5_block3_add (Add)          (None, None, None, 2 0           conv5_block2_out[0][0]                                                                            conv5_block3_3_bn[0][0]          ==================================================================================================Total params: 23,587,712Trainable params: 23,534,592Non-trainable params: 53,120

Keras中的ResNet50输出是最后Conv2D块之后的所有特征图,它不关心模型的分类部分,你实际做的只是移除了最后一个加法块之后的激活层

enter image description here

所以你需要检查你想移除的具体块层,并为分类部分添加flatten和全连接层

正如Dr.Snoopy提到的,不要混合使用keras和tensorflow.keras的导入

# 这一部分from tensorflow.keras.models import Model

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注