使用预训练的ResNet50网络构建OneClass SVM模型

我在尝试构建一个用于图像识别的OneClass分类器。我找到了这篇文章,但由于没有完整的源代码,我不太明白自己在做什么。

X_train, X_test, y_train, y_test = train_test_split(x, y, random_state=42)# X_train (2250, 200, 200, 3)resnet_model = ResNet50(input_shape=(200, 200, 3), weights='imagenet', include_top=False)features_array = resnet_model.predict(X_train)# features_array (2250, 7, 7, 2048)pca = PCA(svd_solver='randomized', n_components=450, whiten=True, random_state=42)svc = SVC(kernel='rbf', class_weight='balanced')model = make_pipeline(pca, svc)param_grid = {'svc__C': [1, 5, 10, 50], 'svc__gamma': [0.0001, 0.0005, 0.001, 0.005]}grid = GridSearchCV(model, param_grid)grid.fit(X_train, y_train)

我有2250张图像(食物和非食物),尺寸为200×200像素,我将这些数据发送到ResNet50模型的predict方法。结果是一个(2250, 7, 7, 2048)的张量,有人知道这个维度是什么意思吗?

当我尝试运行grid.fit方法时,我得到了一个错误:

ValueError: Found array with dim 4. Estimator expected <= 2.

回答:

这些是我能得出的发现。

你得到的是全局平均池化层之上的输出张量。(查看resnet_model.summary()以了解输入维度如何变化到输出维度)

为了简单修复这个问题,可以在resnet_model上添加一个二维平均池化层。(这样输出形状将变为(2250,1,1, 2048))

resnet_model = ResNet50(input_shape=(200, 200, 3), weights='imagenet', include_top=False)resnet_op = AveragePooling2D((7, 7), name='avg_pool_app')(resnet_model.output)resnet_model = Model(resnet_model.input, resnet_op, name="ResNet")

这通常在ResNet50的源代码中已经存在。基本上,我们是在resnet50模型上附加一个AveragePooling2D层。最后一行将层(第二行)和基础模型结合成一个模型对象。

现在,输出维度(feature_array)将是(2250, 1, 1, 2048)(因为添加了平均池化层)。

为了避免ValueError,你需要将这个feature_array重塑为(2250, 2048)

feature_array = np.reshape(feature_array, (-1, 2048))

在问题中程序的最后一行,

grid.fit(X_train, y_train)

你使用了X_train(在这种情况下是图像)进行拟合。这里的正确变量应该是features_array(这被认为是图像的摘要)。输入这一行将纠正错误,

grid.fit(features_array, y_train)

要通过提取特征向量进行更多此类微调,请查看这里(使用神经网络进行训练,而不是使用PCA和SVM)。

希望这对你有帮助!!

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注