如何使用Tensorflow预测用户输入的图像?

在我的项目中,我使用tensorflow来预测手写用户输入。

基本上我使用了这个数据集:https://www.kaggle.com/rishianand/devanagari-character-set,并创建了一个模型。我使用matplotlib来查看由像素生成的图像。

我的代码主要用于训练数据,但我想要提升一下。通过CV2,我创建了一个GUI,允许用户绘制尼泊尔字母。之后,我有分支告诉程序将图像保存到计算机中。

这是我为此写的代码片段:

#creating a forloop to show the imagewhile True:    img=cv2.imshow('window', win) #showing the window    k= cv2.waitKey(1)     if k==ord('c'):        win= np.zeros((500,500,3), dtype='float64') #creating a new image    #saving the image as a file to then resize it    if k==ord('s'):        cv2.imwrite("nepali_character.jpg", win)        img= cv2.imread("nepali_character.jpg")        cv2.imshow('char', img)        #trying to resize the image using Pillow        size=(32,32)        #create a while loop(make the user print stuff until they print something that STOPS it)        im= Image.open("nepali_character.jpg")        out=im.resize(size)        l= out.save('resized.jpg')        imgout= cv2.imread('resized.jpg')        cv2.imshow("out", imgout)        #finding the pixels of the image, will be printed as a matrix        pix= cv2.imread('resized.jpg', 1)        print(pix)    if k==ord('q'): #if k is 27 then we break the window        cv2.destroyAllWindows()        break

我调整图像大小,因为这是数据集中的数据尺寸。

现在我的问题是如何通过tensorflow预测这个字母是什么。

当我问我的老师这个问题时,他说将其放入我的数据文件中,然后将其视为训练数据,然后查看权重,并选择最大的权重?

但我对如何将这个图像放入那个数据文件感到困惑?

如果有人有任何关于如何获取用户输入然后进行预测的建议,将不胜感激


回答:

了解数据集:

  1. 图像大小为32 x 32
  2. 有46个不同的字符/字母
['character_10_yna', 'character_11_taamatar', 'character_12_thaa', 'character_13_daa', 'character_14_dhaa', 'character_15_adna', 'character_16_tabala', 'character_17_tha', 'character_18_da', 'character_19_dha', 'character_1_ka', 'character_20_na', 'character_21_pa', 'character_22_pha', 'character_23_ba', 'character_24_bha', 'character_25_ma', 'character_26_yaw', 'character_27_ra', 'character_28_la', 'character_29_waw', 'character_2_kha', 'character_30_motosaw', 'character_31_petchiryakha', 'character_32_patalosaw', 'character_33_ha', 'character_34_chhya', 'character_35_tra', 'character_36_gya', 'character_3_ga', 'character_4_gha', 'character_5_kna', 'character_6_cha', 'character_7_chha', 'character_8_ja', 'character_9_jha', 'digit_0', 'digit_1', 'digit_2', 'digit_3', 'digit_4', 'digit_5', 'digit_6', 'digit_7', 'digit_8', 'digit_9']

由于您的图像在文件夹中被分类Traing Folder

因此keras实现将是:

import matplotlib.pyplot as pltimport numpy as npimport osimport PILimport tensorflow as tffrom tensorflow import kerasfrom tensorflow.keras import layersfrom tensorflow.keras.models import Sequentialimport pathlibdataDir = "/xx/xx/xx/xx/datasets/Devanagari/drive-download-20210601T224146Z-001/Train"data_dir = keras.utils.get_file(dataDir, 'file://'+dataDir)data_dir = pathlib.Path(data_dir)image_count = len(list(data_dir.glob('*/*.png')))print(image_count)batch_size = 32img_height = 180 # 为了更好的性能,增加尺寸img_width = 180 # 为了更好的性能,增加尺寸train_ds = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,  validation_split=0.2,  subset="training",  seed=123,  image_size=(img_height, img_width),  batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,  validation_split=0.2,  subset="validation",  seed=123,  image_size=(img_height, img_width),  batch_size=batch_size)class_names = train_ds.class_namesprint(class_names) # 46个类别

关于缓存和归一化,请参考tensorflow教程

AUTOTUNE = tf.data.experimental.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))image_batch, labels_batch = next(iter(normalized_ds))first_image = image_batch[0]print(np.min(first_image), np.max(first_image))

模型设置 编译和训练

num_classes = 46model = Sequential([  layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),  layers.Conv2D(16, 3, padding='same', activation='relu'),  layers.MaxPooling2D(),  layers.Conv2D(32, 3, padding='same', activation='relu'),  layers.MaxPooling2D(),  layers.Conv2D(64, 3, padding='same', activation='relu'),  layers.MaxPooling2D(),  layers.Flatten(),  layers.Dense(128, activation='relu'),  layers.Dense(num_classes)])model.compile(optimizer='adam',              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),              metrics=['accuracy'])epochs=10history = model.fit(  train_ds,  validation_data=val_ds,  epochs=epochs)

这将产生以下结果(非常有希望!)

Epoch 10/101955/1955 [==============================] - 924s 472ms/step - loss: 0.0201 - accuracy: 0.9932 - val_loss: 0.2267 - val_accuracy: 0.9504

保存模型(训练需要时间,所以最好保存模型)

!mkdir -p saved_modelmodel.save('saved_model/my_model')

加载模型:

loaded_model = tf.keras.models.load_model('saved_model/my_model')# 检查其架构loaded_model.summary()

现在是最后的任务,获取预测。一种方法如下:

import cv2im2=cv2.imread('datasets/Devanagari/drive-download-20210601T224146Z-001/Test/character_3_ga/3711.png')im2=cv2.resize(im2, (180,180)) # 调整大小到180,180,因为模型是在这个尺寸上训练的print(im2.shape)img2 = tf.expand_dims(im2, 0) # 扩展维度意味着将形状从(180, 180, 3)更改为(1, 180, 180, 3)print(img2.shape)predictions = loaded_model.predict(img2)score = tf.nn.softmax(predictions[0]) # 为每个输出获取softmaxprint(    "这个图像最有可能属于{},置信度为{:.2f}%"    .format(class_names[np.argmax(score)], 100 * np.max(score))) # 获取np.argmax,意味着给我提供概率最大的索引,在这种情况下得到了29。这回答了您从老师那里得到的回应,即“最大权重”
(180, 180, 3)(1, 180, 180, 3)这个图像最有可能属于character_3_ga,置信度为100.00%

另一种方法是通过在线方式,您正在尝试实现的。图像形状需要是(1, 180, 180, 3),对于这个例子,或者如果没有调整大小,可以是(1, 32, 32, 3)。然后将其输入到预测中。类似于下面的代码

out=im.resize(size)out = tf.expand_dims(out, 0)predictions = loaded_model.predict(out)score = tf.nn.softmax(predictions[0]) # 为每个输出获取softmaxprint(    "这个图像最有可能属于{},置信度为{:.2f}%"    .format(class_names[np.argmax(score)], 100 * np.max(score))) 

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注