我想预测用鼠标书写的数字。我使用TensorFlow创建了一个模型并训练了整个数据集。
当我书写一个数字并尝试预测时,得到的答案准确度较低。
请建议一些克服这个问题的办法。
源代码如下:
import cv2import numpy as np import matplotlib.pyplot as pltfrom PIL import Imageimport tensorflow as tfdef plot_digit(data): image = data.reshape(28, 28) plt.imshow(image, interpolation='nearest') plt.axis('off')mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10)])predictions = model(x_train[:1]).numpy()tf.nn.softmax(predictions).numpy()loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)loss_fn(y_train[:1], predictions).numpy()model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])model.fit(x_train, y_train, epochs=10)model.evaluate(x_test, y_test, verbose=2)drawing = False # true if mouse is pressedpt1_x , pt1_y = None , None# mouse callback functiondef line_drawing(event,x,y,flags,param): global pt1_x,pt1_y,drawing if event==cv2.EVENT_LBUTTONDOWN: drawing=True pt1_x,pt1_y=x,y elif event==cv2.EVENT_MOUSEMOVE: if drawing==True: cv2.line(img,(pt1_x,pt1_y),(x,y),color=(255,255,255),thickness=3) pt1_x,pt1_y=x,y elif event==cv2.EVENT_LBUTTONUP: drawing=False cv2.line(img,(pt1_x,pt1_y),(x,y),color=(255,255,255),thickness=3) img = np.zeros((200,200), np.uint8)cv2.namedWindow('test draw')cv2.setMouseCallback('test draw',line_drawing)while(1): cv2.imshow('test draw',img) if cv2.waitKey(1) & 0xFF == 27: breakcv2.destroyAllWindows()img = Image.fromarray(img)foo = img.resize((28,28),Image.ANTIALIAS)foo = np.array(foo)/255.0plot_digit(foo)np.argmax(model.predict(foo.reshape(1,28,28)))
当我写7时,它预测为6。但当我绘制我画的图形时,显示的是7。
回答:
这可能是许多原因造成的。一些建议如下:
1) 可能是调整大小的问题?在200,200
的图像上使用thickness=3
,调整到(28,28)
后就相当于thickness=1
,这不再符合MNIST数据集的特征。尝试可视化一些MNIST数据和鼠标书写的数据,看看它们在(28,28)
尺寸下是否真的相似。
2) 可能是模型对手写数字过拟合了?考虑在模型中使用卷积层,我认为这在这种情况下可以缓解这个问题。
3) 可能是可视化的问题?我看到你在可视化图像时同时使用了ANTIALIAS
和nearest
。尝试去掉nearest
。你还看到你期望的结果吗?
如果你能发布一些你绘制的图像,那将有助于解决问题。