### Python, Tensorflow 导入非数据集图像

我在进行一个关于Python机器学习的学校项目。我使用Tensorflow创建了一个线性分类器,并以超过90%的准确率学习了MNIST数据集。

预测数据集的测试数据没有问题,但问题在于我想导入不在测试数据集中的数据(可能是用画图软件创建的一张图片)。

我为我的展示创建了一个简单的GUI,它在使用时也运行正常,但例如使用.png图片时就不行了。

我尝试过使用Pillow,但看起来效果不好。

你能帮帮我吗?我会接受任何建议。非常感谢。

这是Tensorflow的代码:

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
learn = tf.contrib.learn
tf.logging.set_verbosity(tf.logging.ERROR)
global i, test_labels
i = 0
def display(i):
   img = test_data[i]
   plt.title('Example %d, label %d' % (i, test_labels[i]))
   plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r)
   plt.show()
global mnist
mnist = learn.datasets.load_dataset("mnist")
test_data = mnist.test.images
test_labels = np.array(mnist.test.labels, dtype=np.int32)
def train_me(max_examples, batch, step):
   data = mnist.train.images
   labels = np.array(mnist.train.labels, dtype=np.int32)
   data = data[:max_examples]
   labels = labels[:max_examples]
   feature_columns = learn.infer_real_valued_columns_from_input(data)
   cls = learn.LinearClassifier(feature_columns=feature_columns, 
   n_classes=10)
   cls.fit(data, labels, batch_size=batch, steps=step)
   return cls
def test_me(cls):
   im = Image.open("dva-test.png")
   global prediction
   prediction = cls.predict(im, as_iterable=False)

这是GUI的代码:

import sys
import digits as dig
from PyQt5.QtWidgets import (QApplication, QWidget, QToolTip, 
 QPushButton, QMessageBox, QDesktopWidget, QMainWindow, 
                        QLabel, QAction, QFileDialog)
from PyQt5.QtGui import QIcon
class Gui(QMainWindow):
   def __init__(self):
       super().__init__()
       self.init_ui()
   def init_ui(self):
       self.setFixedSize(500, 200)
       self.center()
       self.statusBar().showMessage('Not trained')
       exAct = QAction('Exit', self)
       exAct.setShortcut('Ctrl+Q')
       exAct.triggered.connect(self.close)
       impAct = QAction('Import picture', self)
       impAct.setShortcut('Ctrl+I')
       impAct.triggered.connect(self.file_import)
       menubar = self.menuBar()
       fileMenu = menubar.addMenu('&File')
       fileMenu.addAction(impAct)
       fileMenu.addAction(exAct)
       trainBtn = QPushButton('Train', self)
       trainBtn.resize(trainBtn.sizeHint())
       trainBtn.move(155, 120)
       trainBtn.clicked.connect(self.trainning)
       testBtn = QPushButton('Test', self)
       testBtn.resize(trainBtn.sizeHint())
       testBtn.move(255, 120)
       testBtn.clicked.connect(self.testing)
       text = QLabel("Please import file and train the classifier before testing.", self)
       text.resize(text.sizeHint())
       text.move(120, 40)
       self.setWindowIcon(QIcon('icon.png'))
       self.setWindowTitle('Digits')
       self.show()
   def trainning(self):
       global classifier
       classifier = dig.train_me(10000, 100, 1000)
       classifier.evaluate(dig.test_data, dig.test_labels)
       self.statusBar().showMessage('Accuracy: ' + 
                                    str(classifier.evaluate(dig.test_data, 
       dig.test_labels)['accuracy']))
   def testing(self):
       dig.i = 2
       dig.test_me(classifier)
       self.statusBar().showMessage("Predicted %d, label: %d" % (dig.prediction, dig.test_labels[dig.i]))
   def file_import(self):
           name = QFileDialog.getOpenFileName(self, 'Import File')
           print(name)
   def closeEvent(self, event):
       reply = QMessageBox.question(self, 'Message', "Are you sure you want to exit ?", 
                                QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
       if reply == QMessageBox.Yes:
           event.accept()
       else:
           event.ignore()
   def center(self):
       qr = self.frameGeometry()
       cp = QDesktopWidget().availableGeometry().center()
       qr.moveCenter(cp)
       self.move(qr.topLeft())
if __name__ == '__main__':
   app = QApplication(sys.argv)
   ui = Gui()
   sys.exit(app.exec_())

回答:

已解决:

Tensorflow只接受一维数组,而我的图像是三维数组,形状为[28, 28, 3]。所以我移除了RGB维度,并展平了二维数组。

然后我将结果导入到Tensorflow分类器中,但意识到我需要反转颜色,因此数组中的每个零应等于1,每个1应等于零。

这是代码:

    im = mpimg.imread('dva-test.png')
    im = im[:, :, 0]
    im = im.ravel()
    for j in range(len(im)):
        if im[j] == 0:
            im[j] = 1
        elif im[j] == 1:
            im[j] = 0 
    global prediction
    prediction = cls.predict(np.array([im], dtype=float), as_iterable=False)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注