预测鼠标书写的数字

我想预测用鼠标书写的数字。我使用TensorFlow创建了一个模型并训练了整个数据集。

当我书写一个数字并尝试预测时,得到的答案准确度较低。

请建议一些克服这个问题的办法。

源代码如下:

import cv2import numpy as np import matplotlib.pyplot as pltfrom PIL import Imageimport tensorflow as tfdef plot_digit(data):    image = data.reshape(28, 28)    plt.imshow(image, interpolation='nearest')    plt.axis('off')mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([  tf.keras.layers.Flatten(input_shape=(28, 28)),  tf.keras.layers.Dense(128, activation='relu'),  tf.keras.layers.Dropout(0.2),  tf.keras.layers.Dense(128, activation='relu'),  tf.keras.layers.Dense(10)])predictions = model(x_train[:1]).numpy()tf.nn.softmax(predictions).numpy()loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)loss_fn(y_train[:1], predictions).numpy()model.compile(optimizer='adam',              loss=loss_fn,              metrics=['accuracy'])model.fit(x_train, y_train, epochs=10)model.evaluate(x_test,  y_test, verbose=2)drawing = False # true if mouse is pressedpt1_x , pt1_y = None , None# mouse callback functiondef line_drawing(event,x,y,flags,param):    global pt1_x,pt1_y,drawing    if event==cv2.EVENT_LBUTTONDOWN:        drawing=True        pt1_x,pt1_y=x,y    elif event==cv2.EVENT_MOUSEMOVE:        if drawing==True:            cv2.line(img,(pt1_x,pt1_y),(x,y),color=(255,255,255),thickness=3)            pt1_x,pt1_y=x,y    elif event==cv2.EVENT_LBUTTONUP:        drawing=False        cv2.line(img,(pt1_x,pt1_y),(x,y),color=(255,255,255),thickness=3)        img = np.zeros((200,200), np.uint8)cv2.namedWindow('test draw')cv2.setMouseCallback('test draw',line_drawing)while(1):    cv2.imshow('test draw',img)    if cv2.waitKey(1) & 0xFF == 27:        breakcv2.destroyAllWindows()img = Image.fromarray(img)foo = img.resize((28,28),Image.ANTIALIAS)foo = np.array(foo)/255.0plot_digit(foo)np.argmax(model.predict(foo.reshape(1,28,28)))

当我写7时,它预测为6。但当我绘制我画的图形时,显示的是7。


回答:

这可能是许多原因造成的。一些建议如下:

1) 可能是调整大小的问题?在200,200的图像上使用thickness=3,调整到(28,28)后就相当于thickness=1,这不再符合MNIST数据集的特征。尝试可视化一些MNIST数据和鼠标书写的数据,看看它们在(28,28)尺寸下是否真的相似。

2) 可能是模型对手写数字过拟合了?考虑在模型中使用卷积层,我认为这在这种情况下可以缓解这个问题。

3) 可能是可视化的问题?我看到你在可视化图像时同时使用了ANTIALIASnearest。尝试去掉nearest。你还看到你期望的结果吗?

如果你能发布一些你绘制的图像,那将有助于解决问题。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注