### Python, Tensorflow 导入非数据集图像

我在进行一个关于Python机器学习的学校项目。我使用Tensorflow创建了一个线性分类器,并以超过90%的准确率学习了MNIST数据集。

预测数据集的测试数据没有问题,但问题在于我想导入不在测试数据集中的数据(可能是用画图软件创建的一张图片)。

我为我的展示创建了一个简单的GUI,它在使用时也运行正常,但例如使用.png图片时就不行了。

我尝试过使用Pillow,但看起来效果不好。

你能帮帮我吗?我会接受任何建议。非常感谢。

这是Tensorflow的代码:

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
learn = tf.contrib.learn
tf.logging.set_verbosity(tf.logging.ERROR)
global i, test_labels
i = 0
def display(i):
   img = test_data[i]
   plt.title('Example %d, label %d' % (i, test_labels[i]))
   plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r)
   plt.show()
global mnist
mnist = learn.datasets.load_dataset("mnist")
test_data = mnist.test.images
test_labels = np.array(mnist.test.labels, dtype=np.int32)
def train_me(max_examples, batch, step):
   data = mnist.train.images
   labels = np.array(mnist.train.labels, dtype=np.int32)
   data = data[:max_examples]
   labels = labels[:max_examples]
   feature_columns = learn.infer_real_valued_columns_from_input(data)
   cls = learn.LinearClassifier(feature_columns=feature_columns, 
   n_classes=10)
   cls.fit(data, labels, batch_size=batch, steps=step)
   return cls
def test_me(cls):
   im = Image.open("dva-test.png")
   global prediction
   prediction = cls.predict(im, as_iterable=False)

这是GUI的代码:

import sys
import digits as dig
from PyQt5.QtWidgets import (QApplication, QWidget, QToolTip, 
 QPushButton, QMessageBox, QDesktopWidget, QMainWindow, 
                        QLabel, QAction, QFileDialog)
from PyQt5.QtGui import QIcon
class Gui(QMainWindow):
   def __init__(self):
       super().__init__()
       self.init_ui()
   def init_ui(self):
       self.setFixedSize(500, 200)
       self.center()
       self.statusBar().showMessage('Not trained')
       exAct = QAction('Exit', self)
       exAct.setShortcut('Ctrl+Q')
       exAct.triggered.connect(self.close)
       impAct = QAction('Import picture', self)
       impAct.setShortcut('Ctrl+I')
       impAct.triggered.connect(self.file_import)
       menubar = self.menuBar()
       fileMenu = menubar.addMenu('&File')
       fileMenu.addAction(impAct)
       fileMenu.addAction(exAct)
       trainBtn = QPushButton('Train', self)
       trainBtn.resize(trainBtn.sizeHint())
       trainBtn.move(155, 120)
       trainBtn.clicked.connect(self.trainning)
       testBtn = QPushButton('Test', self)
       testBtn.resize(trainBtn.sizeHint())
       testBtn.move(255, 120)
       testBtn.clicked.connect(self.testing)
       text = QLabel("Please import file and train the classifier before testing.", self)
       text.resize(text.sizeHint())
       text.move(120, 40)
       self.setWindowIcon(QIcon('icon.png'))
       self.setWindowTitle('Digits')
       self.show()
   def trainning(self):
       global classifier
       classifier = dig.train_me(10000, 100, 1000)
       classifier.evaluate(dig.test_data, dig.test_labels)
       self.statusBar().showMessage('Accuracy: ' + 
                                    str(classifier.evaluate(dig.test_data, 
       dig.test_labels)['accuracy']))
   def testing(self):
       dig.i = 2
       dig.test_me(classifier)
       self.statusBar().showMessage("Predicted %d, label: %d" % (dig.prediction, dig.test_labels[dig.i]))
   def file_import(self):
           name = QFileDialog.getOpenFileName(self, 'Import File')
           print(name)
   def closeEvent(self, event):
       reply = QMessageBox.question(self, 'Message', "Are you sure you want to exit ?", 
                                QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
       if reply == QMessageBox.Yes:
           event.accept()
       else:
           event.ignore()
   def center(self):
       qr = self.frameGeometry()
       cp = QDesktopWidget().availableGeometry().center()
       qr.moveCenter(cp)
       self.move(qr.topLeft())
if __name__ == '__main__':
   app = QApplication(sys.argv)
   ui = Gui()
   sys.exit(app.exec_())

回答:

已解决:

Tensorflow只接受一维数组,而我的图像是三维数组,形状为[28, 28, 3]。所以我移除了RGB维度,并展平了二维数组。

然后我将结果导入到Tensorflow分类器中,但意识到我需要反转颜色,因此数组中的每个零应等于1,每个1应等于零。

这是代码:

    im = mpimg.imread('dva-test.png')
    im = im[:, :, 0]
    im = im.ravel()
    for j in range(len(im)):
        if im[j] == 0:
            im[j] = 1
        elif im[j] == 1:
            im[j] = 0 
    global prediction
    prediction = cls.predict(np.array([im], dtype=float), as_iterable=False)

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注