无法使用Keras中的VGG19对单张图像进行标签预测

我正在按照[这个教程](https://towardsdatascience.com/keras-transfer-learning-for-beginners-6c9b8b7143e)使用迁移学习方法来使用Keras中的预训练VGG19模型。教程展示了如何训练模型,但没有展示如何为预测准备测试图像。

在评论部分中提到:

获取一张图像,使用相同的preprocess_image函数预处理图像,然后调用model.predict(image)。这将给出模型对该图像的预测。使用argmax(prediction),你可以找到图像所属的类别。

我无法在代码中找到名为preprocess_image的函数。我进行了一些搜索,并考虑使用这个教程提出的方法。

但这导致了一个错误,提示如下:

decode_predictions expects a batch of predictions (i.e. a 2D array of shape (samples, 1000)). Found array with shape: (1, 12)

我的数据集有12个类别。以下是训练模型的完整代码以及我如何得到这个错误的:

import pandas as pdimport numpy as npimport osimport kerasimport matplotlib.pyplot as pltfrom keras.layers import Dense, GlobalAveragePooling2Dfrom keras.applications.vgg19 import VGG19from keras.preprocessing import imagefrom keras.applications.vgg19 import preprocess_inputfrom keras.preprocessing.image import ImageDataGeneratorfrom keras.models import Modelfrom keras.optimizers import Adambase_model = VGG19(weights='imagenet', include_top=False)x=base_model.output                                                          x=GlobalAveragePooling2D()(x)                                                x=Dense(1024,activation='relu')(x)                                           x=Dense(1024,activation='relu')(x)                                           x=Dense(512,activation='relu')(x)        preds=Dense(12,activation='softmax')(x)                                      model=Model(inputs=base_model.input,outputs=preds)                           # view the layer architecture# for i,layer in enumerate(model.layers):#   print(i,layer.name)for layer in model.layers:    layer.trainable=Falsefor layer in model.layers[:20]:    layer.trainable=Falsefor layer in model.layers[20:]:    layer.trainable=Truetrain_datagen=ImageDataGenerator(preprocessing_function=preprocess_input)train_generator=train_datagen.flow_from_directory('dataset',                    target_size=(96,96), # 224, 224                    color_mode='rgb',                    batch_size=64,                    class_mode='categorical',                    shuffle=True)model.compile(optimizer='Adam',loss='categorical_crossentropy',metrics=['accuracy'])step_size_train=train_generator.n//train_generator.batch_sizemodel.fit_generator(generator=train_generator,    steps_per_epoch=step_size_train,    epochs=5)# model.predict(new_image)

IPython:

In [3]: import classify_tl                                                                                                                                                   Found 4750 images belonging to 12 classes.Epoch 1/574/74 [==============================] - 583s 8s/step - loss: 2.0113 - acc: 0.4557Epoch 2/574/74 [==============================] - 576s 8s/step - loss: 0.8222 - acc: 0.7170Epoch 3/574/74 [==============================] - 563s 8s/step - loss: 0.5875 - acc: 0.7929Epoch 4/574/74 [==============================] - 585s 8s/step - loss: 0.3897 - acc: 0.8627Epoch 5/574/74 [==============================] - 610s 8s/step - loss: 0.2689 - acc: 0.9071In [6]: model = classify_tl.model                                                                                                                                            In [7]: print(model)                                                                                                                                                         <keras.engine.training.Model object at 0x7fb3ad988518>In [8]: from keras.preprocessing.image import load_img                                                                                                                       In [9]: image = load_img('examples/0021e90e4.png', target_size=(96,96))                                                                                                      In [10]: from keras.preprocessing.image import img_to_array                                                                                                                  In [11]: image = img_to_array(image)                                                                                                                                         In [12]: image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))                                                                                          In [13]: from keras.applications.vgg19 import preprocess_input                                                                                                               In [14]: image = preprocess_input(image)                                                                                                                                     In [15]: yhat = model.predict(image)                                                                                                                                         In [16]: print(yhat)                                                                                                                                                         [[1.3975363e-06 3.1069856e-05 9.9680350e-05 1.7175063e-03 6.2767825e-08  2.6133494e-03 7.2859187e-08 6.0187017e-07 2.0794137e-06 1.3714411e-03  9.9416250e-01 2.6067207e-07]]In [17]: from keras.applications.vgg19 import decode_predictions                                                                                                             In [18]: label = decode_predictions(yhat) 

IPython提示中的最后一行导致了以下错误:

ValueError: `decode_predictions` expects a batch of predictions (i.e. a 2D array of shape (samples, 1000)). Found array with shape: (1, 12)

我应该如何正确地输入我的测试图像并获得预测结果?


回答:

decode_predictions用于解码模型在ImageNet数据集上的预测,该数据集有1000个类别。然而,你的微调模型只有12个类别。因此,这里使用decode_predictions没有意义。当然,你一定知道这12个类别的标签。因此,只需在预测中找到最大分数的索引并找到其标签:

# 创建一个包含类别标签的列表class_labels = ['class1', 'class2', 'class3', ...., 'class12']# 找到最大分数类别的索引pred = np.argmax(class_labels, axis=-1)# 打印最大分数类别的标签print(class_labels[pred[0]])

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注