使用Tensorflow中的CNN进行预测

我正在尝试创建一个二元图像分类器,用于检测某人是否佩戴了外科口罩,参考了TensorFlow网站上的“猫与狗”示例(https://www.tensorflow.org/tutorials/images/classification)。

我已经创建了一个小型数据集,包含了一些佩戴外科口罩的人的图片和一些未佩戴口罩的人的图片,并训练了我的CNN,准确率约为70%,目前还可以。但问题是,如何进行预测呢?“猫与狗”示例在数据增强部分就停止了。

目前我并不担心准确率,只是想知道如何从我的模型中获取预测结果。

这是我的代码:

import sysimport timeimport numpy as npimport matplotlib.pyplot as pltimport osimport cv2import randomimport kerasimport tensorflow as tffrom tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2Dfrom tensorflow.keras.callbacks import TensorBoardfrom tensorflow.keras.preprocessing.image import ImageDataGeneratorfrom keras.optimizers import adamIMG_SIZE = 100 # 图像尺寸batch_size = 100 # 一次性输入到神经网络的数据量epochs = 1 # 数据通过神经网络的次数training_data = []#img_array = []new_array = []#######  数据集位置  #######TRAIN_DIR = 'C:/Users/Alex/Google Drive/Colab Notebooks/MaskDetector/Train/' # 创建存储训练图像目录路径的变量VALIDATION_DIR = 'C:/Users/Alex/Google Drive/Colab Notebooks/MaskDetector/Validate/' # 创建存储验证图像目录路径的变量TEST_DIR = 'C:/Users/Alex/Google Drive/Colab Notebooks/MaskDetector/Test/' # 创建存储测试图像目录路径的变量CATEGORIES = ['MaskOn','MaskOff'] # 类别 'MaskOn' 和 'MaskOff' 与文件夹名称相同TRAIN_DIR_MASKON = os.path.join(TRAIN_DIR, 'MaskOn')  # 训练用佩戴口罩的人的图片目录TRAIN_DIR_MASKOFF = os.path.join(TRAIN_DIR, 'MaskOff')  # 训练用未佩戴口罩的人的图片目录VALIDATION_DIR_MASKON = os.path.join(VALIDATION_DIR, 'MaskOn')  # 验证用佩戴口罩的人的图片目录VALIDATION_DIR_MASKOFF = os.path.join(VALIDATION_DIR, 'MaskOff')  # 验证用未佩戴口罩的人的图片目录##################################################     显示数据集大小  #######num_maskon_tr = len(os.listdir(TRAIN_DIR_MASKON))num_maskoff_tr = len(os.listdir(TRAIN_DIR_MASKOFF))num_maskon_val  = len(os.listdir(VALIDATION_DIR_MASKON))num_maskoff_val = len(os.listdir(VALIDATION_DIR_MASKOFF))total_train = num_maskon_tr + num_maskoff_trtotal_val = num_maskon_val + num_maskoff_val###################### 数据增强  ###############################   翻转   #########image_gen = ImageDataGenerator(rescale=1./255, horizontal_flip=True)train_data_gen = image_gen.flow_from_directory(batch_size=batch_size,                                               directory=TRAIN_DIR,                                               shuffle=True,                                               target_size=(IMG_SIZE, IMG_SIZE))augmented_images = [train_data_gen[0][0][0] for i in range(5)]# 重新使用上面定义并使用的自定义绘图函数来可视化训练图像#########   旋转45°   #########image_gen = ImageDataGenerator(rescale=1./255, rotation_range=45)train_data_gen = image_gen.flow_from_directory(batch_size=batch_size,                                               directory=TRAIN_DIR,                                               shuffle=True,                                               target_size=(IMG_SIZE, IMG_SIZE))augmented_images = [train_data_gen[0][0][0] for i in range(5)]#########   缩放0到10%   ########## zoom_range从0到1,其中1=100%。image_gen = ImageDataGenerator(rescale=1./255, zoom_range=0.5) #train_data_gen = image_gen.flow_from_directory(batch_size=batch_size,                                               directory=TRAIN_DIR,                                               shuffle=True,                                               target_size=(IMG_SIZE, IMG_SIZE))augmented_images = [train_data_gen[0][0][0] for i in range(5)]####################################################################################################### 准备数据输入神经网络 ###train_image_generator = ImageDataGenerator(rescale=1./255)image_gen_val = ImageDataGenerator(rescale=1./255)test_data_generator = ImageDataGenerator(rescale=1./255) # 用于预测吗?目前还不确定val_data_gen = image_gen_val.flow_from_directory(batch_size=batch_size,                                                 directory=VALIDATION_DIR,                                                 target_size=(IMG_SIZE, IMG_SIZE),                                                 class_mode='binary')####################################################################################train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,                                                           directory=TRAIN_DIR,                                                           shuffle=True,                                                           target_size=(IMG_SIZE, IMG_SIZE),                                                           class_mode='binary')#####################################################################################test_generator = test_data_generator.flow_from_directory(TEST_DIR,                                                         target_size=(IMG_SIZE, IMG_SIZE),                                                         batch_size=batch_size,                                                         class_mode="binary",                                                         shuffle=True)############   神经网络模型    ############model = Sequential([    Conv2D(16, 3, padding='same', activation='relu',           input_shape=(IMG_SIZE, IMG_SIZE ,3)),    MaxPooling2D(),    Dropout(0.2),    Conv2D(32, 3, padding='same', activation='relu'),    MaxPooling2D(),    Conv2D(64, 3, padding='same', activation='relu'),    MaxPooling2D(),    Dropout(0.2),    Flatten(),    Dense(512, activation='relu'),    Dense(1)])#####################################################################   编译神经网络    ############model.compile(optimizer='adam',              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),              metrics=['accuracy'])############   执行训练    ############history = model.fit_generator(    train_data_gen,    steps_per_epoch=total_train // batch_size,    epochs=epochs,    validation_data=val_data_gen,    validation_steps=total_val // batch_size)

回答:

你已经找到了model.compile()model.fit_generator() – 你只需要查看文档并找到其他方法。这里有一个链接,它会告诉你如何使用model.predict()。用它来进行你的预测吧。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注