尝试用自定义飞机图片测试我的CIFAR-10训练的CNN

在学习了Python中Keras的CNN基础知识后,我尝试添加自己的飞机图片来测试我的训练程序。为此,我尝试导入了一张1920×1080的png图片,经过一些研究,我找到了可能的重塑图像的方法,但得到了以下错误消息:

Traceback (most recent call last): File “C:/Users/me/Desktop/Programming Courses/Image_Classifier_Project/Model_Test.py”, line 21, in img = np.reshape(img, [1, 32, 32, 3]) File “<array_function internals>”, line 6, in reshape File “C:\Users\me\AppData\Roaming\Python\Python37\site-packages\numpy\core\fromnumeric.py”, line 301, in reshape return _wrapfunc(a, ‘reshape’, newshape, order=order) File “C:\Users\me\AppData\Roaming\Python\Python37\site-packages\numpy\core\fromnumeric.py”, line 61, in _wrapfunc return bound(*args, **kwds) ValueError: cannot reshape array of size 1024 into shape (1,32,32,3)

该图片是彩色的(与训练图片一样)。

这是代码。我正在从文件中调用我的训练结果。

from keras.datasets import cifar10import keras.utils as utilsfrom keras.models import load_modelimport numpy as npimport cv2# Get Model Datalabels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'](_, _), (x_test, y_test) = cifar10.load_data()x_test = x_test.astype('float32') / 255.0y_test = utils.to_categorical(y_test)model = load_model('Classified.h5')img = cv2.imread("a400m.png", 0)img = cv2.resize(img, (32, 32))img = np.reshape(img, [1, 32, 32, 3])# results = model.evaluate(x=x_test, y=y_test)# print("Loss: ", results[0])# print("Accuracy", results[1])test_image_data = np.asarray(img)prediction = model.predict(x=test_image_data)print("Prediction: ", labels[np.argmax(prediction)])# max_index = np.argmax(prediction[0])# print("Prediction: ", labels[max_index])

抱歉代码有些混乱,只是尝试实现它,而不是从头开始

提前感谢!


回答:

首先,您的图片是彩色的,因此需要将其作为彩色图片加载:

img = cv2.imread("a400m.png", 1)  # 0表示灰度

其次:

img = cv2.resize(img, (32, 32)) #给出形状(32, 32, 3)

这行代码将把形状为(1920, 1080, 3)的图片重塑为形状(32, 32, 3)

最后,为了对该图片进行预测,您需要扩展其维度,为此使用numpy的expand_dim函数:

img = np.expand_dims(img, 0) #给出形状(1, 32, 32, 3),0表示第一维

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注