我使用Keras和MNIST数据集构建了一个神经网络,现在我想用它来处理实际手写数字的照片。当然,我不期望结果完美,但我目前得到的结果还有很大的提升空间。
首先,我用一些我最清晰的书写风格测试了单个数字的照片。它们是方形的,尺寸和颜色与MNIST数据集中的图像相同。它们被保存在名为individual_test的文件夹中,例如:7(2)_digit.jpg。
网络经常对错误的结果非常确定,下面我给出一个例子:
我对这张图片得到的结果如下:
result: 3 . probabilities: [1.9963557196245318e-10, 7.241294497362105e-07, 0.02658148668706417, 0.9726449251174927, 2.5416460047722467e-08, 2.6078915027483163e-08, 0.00019745019380934536, 4.8302300825753264e-08, 0.0005754049634560943, 2.8358477788259506e-09]
所以网络97%确定这是3,而这张图片绝不是唯一的案例。在38张图片中,只有16张被正确识别。让我震惊的是,网络对其结果如此确定,尽管它与正确结果相去甚远。
编辑
在prepare_image中添加阈值后(img = cv2.threshold(img, 0.1, 1, cv2.THRESH_BINARY_INV)[1]
),性能略有提升。现在它正确识别了38张图片中的19张,但对于包括上面显示的图片在内的一些图片,它仍然对错误的结果非常确定。现在我得到的结果是:
result: 3 . probabilities: [1.0909866760000497e-11, 1.1584616004256532e-06, 0.27739930152893066, 0.7221096158027649, 1.900260038212309e-08, 6.555900711191498e-08, 4.479645940591581e-05, 6.455550760620099e-07, 0.0004443934594746679, 1.0013242457418414e-09]
所以现在它只对结果72%确定,这比之前好一些,但仍然…
我可以做些什么来提升性能?我能更好地准备我的图片吗?或者我应该将自己的图片添加到训练数据中?如果是,我该如何操作?
编辑
这是上面显示的图片在应用prepare_image后的样子:
使用阈值后,同样的图片看起来是这样的:
相比之下,这是MNIST数据集提供的一张图片:
在我看来,它们看起来相当相似。我该如何改进这个?
这是我的代码(包括阈值处理):
# 导入Keras和MNIST数据集from tensorflow.keras.datasets import mnistfrom tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Densefrom keras.utils import np_utils# Keras使用numpy数组,所以需要numpyimport numpy as np# 导入处理图片的库import matplotlib.pyplot as pltimport PILimport cv2# 导入测试用的库import randomimport osclass mnist_network(): def __init__(self): """ 加载数据,创建并训练模型 """ # 加载数据 (X_train, y_train), (X_test, y_test) = mnist.load_data() # 将28*28的图像展平为每个图像的784向量 num_pixels = X_train.shape[1] * X_train.shape[2] X_train = X_train.reshape((X_train.shape[0], num_pixels)).astype('float32') X_test = X_test.reshape((X_test.shape[0], num_pixels)).astype('float32') # 将输入从0-255归一化到0-1 X_train = X_train / 255 X_test = X_test / 255 # 对输出进行one-hot编码 y_train = np_utils.to_categorical(y_train) y_test = np_utils.to_categorical(y_test) num_classes = y_test.shape[1] # 创建模型 self.model = Sequential() self.model.add(Dense(num_pixels, input_dim=num_pixels, kernel_initializer='normal', activation='relu')) self.model.add(Dense(num_classes, kernel_initializer='normal', activation='softmax')) # 编译模型 self.model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) # 训练模型 self.model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200, verbose=2) self.train_img = X_train self.train_res = y_train self.test_img = X_test self.test_res = y_test def predict_result(self, img, show = False): """ 预测图片(向量)中的数字 """ assert type(img) == np.ndarray and img.shape == (784,) if show: img = img.reshape((28, 28)) # 显示图片 plt.imshow(img, cmap='Greys') plt.show() img = img.reshape(img.shape[0] * img.shape[1]) num_pixels = img.shape[0] # 实际的数字 res_number = np.argmax(self.model.predict(img.reshape(-1,num_pixels)), axis = 1) # 概率 res_probabilities = self.model.predict(img.reshape(-1,num_pixels)) return (res_number[0], res_probabilities.tolist()[0]) # 我们只需要第一个元素,因为它们只有一个 def prepare_image(self, img, show = False): """ 准备partial_img_rec中使用的部分图像,通过将它们转换成网络能够处理的numpy数组 """ # 转换为灰度 img = img.convert("L") # 调整图像尺寸为28 * 28 img = img.resize((28,28), PIL.Image.ANTIALIAS) # 反转颜色,因为训练图像有黑色背景 #img = PIL.ImageOps.invert(img) # 转换为向量 img = np.asarray(img, "float32") img = img / 255. img[img < 0.5] = 0. img = cv2.threshold(img, 0.1, 1, cv2.THRESH_BINARY_INV)[1] if show: plt.imshow(img, cmap = "Greys") # 将图像展平为28*28 = 784向量 num_pixels = img.shape[0] * img.shape[1] img = img.reshape(num_pixels) return img def partial_img_rec(self, image, upper_left, lower_right, results=[], show = False): """ partial是图像的一部分 """ left_x, left_y = upper_left right_x, right_y = lower_right print("当前测试部分: ", upper_left, lower_right) print("结果: ", results) # 停止递归的条件:我们已经到达图片的全宽 width, height = image.size if right_x > width: return results partial = image.crop((left_x, left_y, right_x, right_y)) if show: partial.show() partial = self.prepare_image(partial) step = height // 10 # 这部分图像中有数字吗? res, prop = self.predict_result(partial) print("结果: ", res, ". 概率: ", prop) # 只有当网络至少50%确定时才计算这个结果 if prop[res] >= 0.5: results.append(res) # 步长是部分图像大小的80%(相当于原始图像的高度) step = int(height * 0.8) print("找到有效结果") else: # 如果没有找到数字,我们采取更小的步长 step = height // 20 print("步长: ", step) # 递归调用,修改位置(移动步长变量) return self.partial_img_rec(image, (left_x + step, left_y), (right_x + step, right_y), results = results) def individual_digits(self, img): """ 使用partial_img_rec来预测方形图像中的单个数字 """ assert type(img) == PIL.JpegImagePlugin.JpegImageFile or type(img) == PIL.PngImagePlugin.PngImageFile or type(img) == PIL.Image.Image return self.partial_img_rec(img, (0,0), (img.size[0], img.size[1]), results=[]) def test_individual_digits(self): """ 使用一些单个数字(形状:方形)测试partial_img_rec,这些数字保存在'individual_test'文件夹中,文件名模式为'number_digit.jpg' """ cnt_right, cnt_wrong = 0,0 folder_content = os.listdir(".\individual_test") for imageName in folder_content: # 图像文件必须是jpg或png assert imageName[-4:] == ".jpg" or imageName[-4:] == ".png" correct_res = int(imageName[0]) image = PIL.Image.open(".\\individual_test\\" + imageName).convert("L") # 这个测试中只有方形图像 if image.size[0] != image.size[1]: print(imageName, " 的比例不正确: ", image.size,". 它必须是方形的。") continue predicted_res = self.individual_digits(image) if predicted_res == []: print("无法预测 ", imageName) else: predicted_res = predicted_res[0] if predicted_res != correct_res: print("partial_img-rec出错!预测为 ", predicted_res, ". 正确结果应该是 ", correct_res) cnt_wrong += 1 else: cnt_right += 1 print("正确预测了 ",imageName) print(cnt_right, " 个中的 ", cnt_right + cnt_wrong," 个数字被正确识别。因此成功率为 ", (cnt_right / (cnt_right + cnt_wrong)) * 100," %。") def multiple_digits(self, img): """ 输入一个没有多余空白包围数字的图像 """ #assert type(img) == myImage width, height = img.size # 从图像的第一个方形部分开始 res_list = self.partial_img_rec(img, (0,0),(height ,height), results = []) res_str = "" for elem in res_list: res_str += str(elem) return res_str def test_multiple_digits(self): """ 使用保存在'multi_test'文件夹中的一些图像测试'multiple_digits'函数。这些图像包含多个手写数字,周围没有太多空白。正确的解决方案保存在文件名中,后跟字符'_'。 """ cnt_right, cnt_wrong = 0,0 folder_content = os.listdir(".\multi_test") for imageName in folder_content: # 图像文件必须是jpg或png assert imageName[-4:] == ".jpg" or imageName[-4:] == ".png" image = PIL.Image.open(".\\multi_test\\" + imageName).convert("L") correct_res = imageName.split("_")[0] predicted_res = self.multiple_digits(image) if correct_res == predicted_res: cnt_right += 1 else: cnt_wrong += 1 print("multiple_digits出错!网络预测为 ", predicted_res, " 但正确结果应该是 ", correct_res) print("网络正确预测了 ", cnt_right, " 张中的 ", cnt_right + cnt_wrong, " 张图片。成功率为 ", cnt_right / (cnt_right + cnt_wrong) * 100, "%。")network = mnist_network()# 这是上面显示的图像result = network.individual_digits(PIL.Image.open(".\individual_test\\7(2)_digit.jpg"))
回答:
更新:
在这一特定任务中,你有三种选择来获得更好的性能:
- 使用卷积网络,因为它在处理具有空间数据的任务(如图像)时表现更好,并且是更具生成性的分类器,就像这个例子一样。
- 使用或创建和/或生成更多你类型的图片,并用它们训练你的网络,使你的网络也能学会识别它们。
- 预处理你的图像,使其更好地与你之前训练网络时使用的原始MNIST图像对齐。
我刚刚做了一个实验。我检查了MNIST图像中每个数字的代表性。我拿了你的图片,并进行了之前建议的一些预处理,比如:
1. 向下设定了阈值,消除了背景噪音,因为原始MNIST数据仅对空白背景设定了最小的阈值:
image[image < 0.1] = 0.
2. 令人惊讶的是,图像内数字的大小被证明是关键的,所以我在28 x 28的图像内缩放了数字,例如,我们在数字周围有更多的填充空间。
3. 我反转了图像,因为Keras中的MNIST数据也是反转的。
image = ImageOps.invert(image)
4. 最后,使用与训练时相同的方法缩放数据:
image = image / 255.
在预处理后,我使用参数epochs=12, batch_size=200
训练了模型,结果如下:
结果:1,概率:0.6844741106033325
result: **1** . probabilities: [2.0584749904628552e-07, 0.9875971674919128, 5.821426839247579e-06, 4.979299319529673e-07, 0.012240586802363396, 1.1566483948399764e-07, 2.382085284580171e-08, 0.00013023221981711686, 9.620113416985987e-08, 2.5273093342548236e-05]
结果:6,概率:0.9221984148025513
result: 6 . probabilities: [9.130864782491699e-05, 1.8290626258021803e-07, 0.00020504613348748535, 2.1564576968557958e-07, 0.0002401985548203811, 0.04510130733251572, 0.9221984148025513, 1.9014490248991933e-07, 0.03216308355331421, 3.323434683011328e-08]
结果:7,概率:0.7105212807655334注意:
result: 7 . probabilities: [1.0372193770535887e-08, 7.988557626958936e-06, 0.00031014863634482026, 0.0056108818389475346, 2.434678014751057e-09, 3.2280522077599016e-07, 1.4190952857262573e-09, 0.9940618872642517, 1.612859932720312e-06, 7.102244126144797e-06]