我利用一些空闲时间快速学习了一些Python和Keras。我创建了一个图像集,其中包含4050张类别a(三叶草)的图片和2358张类别b(草地)的图片。由于可能还会增加其他类别,所以我没有选择二元分类模式。
这些图片按类别分别存放在不同的子文件夹中,我随机将其分为70%的训练数据和30%的测试数据,并按照相应的文件夹结构进行组织。目前,训练和测试数据尚未进行归一化处理。
我训练了模型并保存了结果。训练准确率大约为90%。然而,当我尝试预测单张图片时(这是所需的用例),预测的平均准确率约为64%,这与整体类别a图片的百分比(4050 / (4050+2358) = ~63%)非常接近。为了进行这个测试,我使用了实际数据集中随机选择的图片,但即使是使用真正的新数据,结果也同样不理想。观察预测结果,模型大多数情况下预测为类别a,偶尔预测为类别b。为什么会这样?我不知道哪里出了问题。您能帮我看看吗?
模型的构建如下:
epochs = 50IMG_HEIGHT = 50IMG_WIDTH = 50train_image_generator = ImageDataGenerator( rescale=1./255, rotation_range=45, width_shift_range=.15, height_shift_range=.15, horizontal_flip=True, zoom_range=0.1)validation_image_generator = ImageDataGenerator(rescale=1./255)train_path = os.path.join(global_dir,"Train")validate_path = os.path.join(global_dir,"Validate")train_data_gen = train_image_generator.flow_from_directory(directory=train_path, shuffle=True, target_size=(IMG_HEIGHT, IMG_WIDTH), class_mode='categorical')val_data_gen = validation_image_generator.flow_from_directory(directory=validate_path, shuffle=True, target_size=(IMG_HEIGHT, IMG_WIDTH), class_mode='categorical')model = Sequential([ Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)), MaxPooling2D(), Conv2D(32, 3, padding='same', activation='relu'), MaxPooling2D(), Dropout(0.2), Conv2D(64, 3, padding='same', activation='relu'), MaxPooling2D(), Dropout(0.2), Flatten(), Dense(512, activation='relu'), Dense(64, activation='relu'), Dense(2, activation='softmax') ])model.compile(optimizer='adam', loss=keras.losses.categorical_crossentropy, metrics=['accuracy'])model.summary()history = model.fit( train_data_gen, batch_size=200, epochs=epochs, validation_data=val_data_gen)model.save(global_dir + "/Model/1)
训练输出的结果如下:
Model: "sequential"_________________________________________________________________Layer (type) Output Shape Param # =================================================================conv2d (Conv2D) (None, 50, 50, 16) 448 _________________________________________________________________max_pooling2d (MaxPooling2D) (None, 25, 25, 16) 0 _________________________________________________________________conv2d_1 (Conv2D) (None, 25, 25, 32) 4640 _________________________________________________________________max_pooling2d_1 (MaxPooling2 (None, 12, 12, 32) 0 _________________________________________________________________dropout (Dropout) (None, 12, 12, 32) 0 _________________________________________________________________conv2d_2 (Conv2D) (None, 12, 12, 64) 18496 _________________________________________________________________max_pooling2d_2 (MaxPooling2 (None, 6, 6, 64) 0 _________________________________________________________________dropout_1 (Dropout) (None, 6, 6, 64) 0 _________________________________________________________________flatten (Flatten) (None, 2304) 0 _________________________________________________________________dense (Dense) (None, 512) 1180160 _________________________________________________________________dense_1 (Dense) (None, 64) 32832 _________________________________________________________________dense_2 (Dense) (None, 2) 130 =================================================================Total params: 1,236,706Trainable params: 1,236,706Non-trainable params: 0_________________________________________________________________Epoch 1/50141/141 [==============================] - 14s 102ms/step - loss: 0.6216 - accuracy: 0.6468 - val_loss: 0.5396 - val_accuracy: 0.7120Epoch 2/50141/141 [==============================] - 12s 86ms/step - loss: 0.5129 - accuracy: 0.7488 - val_loss: 0.4427 - val_accuracy: 0.8056Epoch 3/50141/141 [==============================] - 12s 86ms/step - loss: 0.4917 - accuracy: 0.7624 - val_loss: 0.5004 - val_accuracy: 0.7705Epoch 4/50141/141 [==============================] - 15s 104ms/step - loss: 0.4510 - accuracy: 0.7910 - val_loss: 0.4226 - val_accuracy: 0.8198Epoch 5/50141/141 [==============================] - 12s 85ms/step - loss: 0.4056 - accuracy: 0.8219 - val_loss: 0.3439 - val_accuracy: 0.8514Epoch 6/50141/141 [==============================] - 12s 84ms/step - loss: 0.3904 - accuracy: 0.8295 - val_loss: 0.3207 - val_accuracy: 0.8646Epoch 7/50141/141 [==============================] - 12s 85ms/step - loss: 0.3764 - accuracy: 0.8304 - val_loss: 0.3185 - val_accuracy: 0.8702Epoch 8/50141/141 [==============================] - 12s 87ms/step - loss: 0.3695 - accuracy: 0.8362 - val_loss: 0.2958 - val_accuracy: 0.8743Epoch 9/50141/141 [==============================] - 12s 84ms/step - loss: 0.3455 - accuracy: 0.8574 - val_loss: 0.3096 - val_accuracy: 0.8687Epoch 10/50141/141 [==============================] - 12s 84ms/step - loss: 0.3483 - accuracy: 0.8473 - val_loss: 0.3552 - val_accuracy: 0.8412Epoch 11/50141/141 [==============================] - 12s 84ms/step - loss: 0.3362 - accuracy: 0.8616 - val_loss: 0.3004 - val_accuracy: 0.8804Epoch 12/50141/141 [==============================] - 12s 85ms/step - loss: 0.3277 - accuracy: 0.8616 - val_loss: 0.2974 - val_accuracy: 0.8733Epoch 13/50141/141 [==============================] - 12s 85ms/step - loss: 0.3243 - accuracy: 0.8589 - val_loss: 0.2732 - val_accuracy: 0.8931Epoch 14/50141/141 [==============================] - 12s 84ms/step - loss: 0.3324 - accuracy: 0.8563 - val_loss: 0.2568 - val_accuracy: 0.8941Epoch 15/50141/141 [==============================] - 12s 84ms/step - loss: 0.3071 - accuracy: 0.8701 - val_loss: 0.2706 - val_accuracy: 0.8911Epoch 16/50141/141 [==============================] - 12s 84ms/step - loss: 0.3114 - accuracy: 0.8696 - val_loss: 0.2503 - val_accuracy: 0.9059Epoch 17/50141/141 [==============================] - 12s 85ms/step - loss: 0.2978 - accuracy: 0.8794 - val_loss: 0.2853 - val_accuracy: 0.8896Epoch 18/50141/141 [==============================] - 12s 85ms/step - loss: 0.3029 - accuracy: 0.8725 - val_loss: 0.2458 - val_accuracy: 0.9033Epoch 19/50141/141 [==============================] - 12s 84ms/step - loss: 0.2988 - accuracy: 0.8721 - val_loss: 0.2713 - val_accuracy: 0.8916Epoch 20/50141/141 [==============================] - 12s 88ms/step - loss: 0.2960 - accuracy: 0.8747 - val_loss: 0.2649 - val_accuracy: 0.8926Epoch 21/50141/141 [==============================] - 13s 92ms/step - loss: 0.2901 - accuracy: 0.8819 - val_loss: 0.2611 - val_accuracy: 0.8957Epoch 22/50141/141 [==============================] - 12s 89ms/step - loss: 0.2879 - accuracy: 0.8821 - val_loss: 0.2497 - val_accuracy: 0.8947Epoch 23/50141/141 [==============================] - 12s 88ms/step - loss: 0.2831 - accuracy: 0.8817 - val_loss: 0.2396 - val_accuracy: 0.9069Epoch 24/50141/141 [==============================] - 12s 89ms/step - loss: 0.2856 - accuracy: 0.8799 - val_loss: 0.2386 - val_accuracy: 0.9059Epoch 25/50141/141 [==============================] - 12s 87ms/step - loss: 0.2834 - accuracy: 0.8817 - val_loss: 0.2472 - val_accuracy: 0.9048Epoch 26/50141/141 [==============================] - 12s 88ms/step - loss: 0.3038 - accuracy: 0.8768 - val_loss: 0.2792 - val_accuracy: 0.8835Epoch 27/50141/141 [==============================] - 13s 91ms/step - loss: 0.2786 - accuracy: 0.8854 - val_loss: 0.2326 - val_accuracy: 0.9079Epoch 28/50141/141 [==============================] - 12s 86ms/step - loss: 0.2692 - accuracy: 0.8846 - val_loss: 0.2325 - val_accuracy: 0.9115Epoch 29/50141/141 [==============================] - 12s 88ms/step - loss: 0.2770 - accuracy: 0.8841 - val_loss: 0.2507 - val_accuracy: 0.8972Epoch 30/50141/141 [==============================] - 13s 92ms/step - loss: 0.2751 - accuracy: 0.8886 - val_loss: 0.2329 - val_accuracy: 0.9104Epoch 31/50141/141 [==============================] - 12s 88ms/step - loss: 0.2902 - accuracy: 0.8785 - val_loss: 0.2901 - val_accuracy: 0.8758Epoch 32/50141/141 [==============================] - 13s 94ms/step - loss: 0.2665 - accuracy: 0.8915 - val_loss: 0.2314 - val_accuracy: 0.9089Epoch 33/50141/141 [==============================] - 13s 91ms/step - loss: 0.2797 - accuracy: 0.8805 - val_loss: 0.2708 - val_accuracy: 0.8921Epoch 34/50141/141 [==============================] - 13s 90ms/step - loss: 0.2895 - accuracy: 0.8799 - val_loss: 0.2332 - val_accuracy: 0.9140Epoch 35/50141/141 [==============================] - 13s 93ms/step - loss: 0.2696 - accuracy: 0.8857 - val_loss: 0.2512 - val_accuracy: 0.8972Epoch 36/50141/141 [==============================] - 13s 90ms/step - loss: 0.2641 - accuracy: 0.8868 - val_loss: 0.2304 - val_accuracy: 0.9104Epoch 37/50141/141 [==============================] - 13s 94ms/step - loss: 0.2675 - accuracy: 0.8895 - val_loss: 0.2706 - val_accuracy: 0.8830Epoch 38/50141/141 [==============================] - 12s 88ms/step - loss: 0.2699 - accuracy: 0.8839 - val_loss: 0.2285 - val_accuracy: 0.9053Epoch 39/50141/141 [==============================] - 12s 87ms/step - loss: 0.2577 - accuracy: 0.8917 - val_loss: 0.2469 - val_accuracy: 0.9043Epoch 40/50141/141 [==============================] - 12s 87ms/step - loss: 0.2547 - accuracy: 0.8948 - val_loss: 0.2205 - val_accuracy: 0.9074Epoch 41/50141/141 [==============================] - 12s 86ms/step - loss: 0.2553 - accuracy: 0.8930 - val_loss: 0.2494 - val_accuracy: 0.9038Epoch 42/50141/141 [==============================] - 14s 97ms/step - loss: 0.2705 - accuracy: 0.8883 - val_loss: 0.2263 - val_accuracy: 0.9109Epoch 43/50141/141 [==============================] - 12s 88ms/step - loss: 0.2521 - accuracy: 0.8926 - val_loss: 0.2319 - val_accuracy: 0.9084Epoch 44/50141/141 [==============================] - 12s 84ms/step - loss: 0.2694 - accuracy: 0.8850 - val_loss: 0.2199 - val_accuracy: 0.9109Epoch 45/50141/141 [==============================] - 12s 83ms/step - loss: 0.2601 - accuracy: 0.8901 - val_loss: 0.2318 - val_accuracy: 0.9079Epoch 46/50141/141 [==============================] - 12s 83ms/step - loss: 0.2535 - accuracy: 0.8917 - val_loss: 0.2342 - val_accuracy: 0.9089Epoch 47/50141/141 [==============================] - 12s 84ms/step - loss: 0.2584 - accuracy: 0.8897 - val_loss: 0.2238 - val_accuracy: 0.9089Epoch 48/50141/141 [==============================] - 12s 83ms/step - loss: 0.2580 - accuracy: 0.8944 - val_loss: 0.2219 - val_accuracy: 0.9120Epoch 49/50141/141 [==============================] - 12s 83ms/step - loss: 0.2514 - accuracy: 0.8895 - val_loss: 0.2225 - val_accuracy: 0.9150Epoch 50/50141/141 [==============================] - 12s 83ms/step - loss: 0.2483 - accuracy: 0.8977 - val_loss: 0.2370 - val_accuracy: 0.9084
预测代码如下:
model = tf.keras.models.load_model(global_dir + "/Model/1")image = cv.resize(image,(50,50)) image= image.astype('float32')/255image= np.expand_dims(image, axis=0)predictions = model.predict(image)top = np.array(tf.argmax(predictions, 1))result = top[0]
这个函数收集所有输入图片并保存分类(0,1),然后打乱数组。之后,我遍历数组,预测图片并将结果与实际类别进行比较。
def test_model(): dir_good = os.fsencode(global_dir + "/Contours/Clover") dir_bad = os.fsencode(global_dir + "/Contours/Grass") test = [] for file2 in os.listdir(dir_good): filename2 = os.fsdecode(file2) if (filename2.endswith(".jpg")): test.append([0,os.path.join(global_dir + "/Contours/Clover", filename2)]) for file2 in os.listdir(dir_bad): filename2 = os.fsdecode(file2) if (filename2.endswith(".jpg")): test.append([1,os.path.join(global_dir + "/Contours/Grass", filename2)]) random.shuffle(test) count = 0 right = 0 for i in range(0,len(test)): tmp = cv.imread(test[i][1]) result = predict_image(tmp) #<--- 这个函数已经在上面引用过了 count += 1 right += (1 if result == test[i][0] else 0) print(str(test[i][0]) + "->" + str(result),count,right,round(right/count*100,1))
提前感谢您的帮助!祝好,Seb
回答:
正如我们之前讨论的,您使用cv2.imread
来加载图片,该函数以BGR格式加载颜色通道。而Keras数据生成器在内部以RGB格式加载图片。在进行推理之前,您必须反转通道:
tmp = tmp[...,::-1]