我在进行一个关于Python机器学习的学校项目。我使用Tensorflow创建了一个线性分类器,并以超过90%的准确率学习了MNIST数据集。
预测数据集的测试数据没有问题,但问题在于我想导入不在测试数据集中的数据(可能是用画图软件创建的一张图片)。
我为我的展示创建了一个简单的GUI,它在使用时也运行正常,但例如使用.png图片时就不行了。
我尝试过使用Pillow,但看起来效果不好。
你能帮帮我吗?我会接受任何建议。非常感谢。
这是Tensorflow的代码:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
learn = tf.contrib.learn
tf.logging.set_verbosity(tf.logging.ERROR)
global i, test_labels
i = 0
def display(i):
img = test_data[i]
plt.title('Example %d, label %d' % (i, test_labels[i]))
plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r)
plt.show()
global mnist
mnist = learn.datasets.load_dataset("mnist")
test_data = mnist.test.images
test_labels = np.array(mnist.test.labels, dtype=np.int32)
def train_me(max_examples, batch, step):
data = mnist.train.images
labels = np.array(mnist.train.labels, dtype=np.int32)
data = data[:max_examples]
labels = labels[:max_examples]
feature_columns = learn.infer_real_valued_columns_from_input(data)
cls = learn.LinearClassifier(feature_columns=feature_columns,
n_classes=10)
cls.fit(data, labels, batch_size=batch, steps=step)
return cls
def test_me(cls):
im = Image.open("dva-test.png")
global prediction
prediction = cls.predict(im, as_iterable=False)
这是GUI的代码:
import sys
import digits as dig
from PyQt5.QtWidgets import (QApplication, QWidget, QToolTip,
QPushButton, QMessageBox, QDesktopWidget, QMainWindow,
QLabel, QAction, QFileDialog)
from PyQt5.QtGui import QIcon
class Gui(QMainWindow):
def __init__(self):
super().__init__()
self.init_ui()
def init_ui(self):
self.setFixedSize(500, 200)
self.center()
self.statusBar().showMessage('Not trained')
exAct = QAction('Exit', self)
exAct.setShortcut('Ctrl+Q')
exAct.triggered.connect(self.close)
impAct = QAction('Import picture', self)
impAct.setShortcut('Ctrl+I')
impAct.triggered.connect(self.file_import)
menubar = self.menuBar()
fileMenu = menubar.addMenu('&File')
fileMenu.addAction(impAct)
fileMenu.addAction(exAct)
trainBtn = QPushButton('Train', self)
trainBtn.resize(trainBtn.sizeHint())
trainBtn.move(155, 120)
trainBtn.clicked.connect(self.trainning)
testBtn = QPushButton('Test', self)
testBtn.resize(trainBtn.sizeHint())
testBtn.move(255, 120)
testBtn.clicked.connect(self.testing)
text = QLabel("Please import file and train the classifier before testing.", self)
text.resize(text.sizeHint())
text.move(120, 40)
self.setWindowIcon(QIcon('icon.png'))
self.setWindowTitle('Digits')
self.show()
def trainning(self):
global classifier
classifier = dig.train_me(10000, 100, 1000)
classifier.evaluate(dig.test_data, dig.test_labels)
self.statusBar().showMessage('Accuracy: ' +
str(classifier.evaluate(dig.test_data,
dig.test_labels)['accuracy']))
def testing(self):
dig.i = 2
dig.test_me(classifier)
self.statusBar().showMessage("Predicted %d, label: %d" % (dig.prediction, dig.test_labels[dig.i]))
def file_import(self):
name = QFileDialog.getOpenFileName(self, 'Import File')
print(name)
def closeEvent(self, event):
reply = QMessageBox.question(self, 'Message', "Are you sure you want to exit ?",
QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
if reply == QMessageBox.Yes:
event.accept()
else:
event.ignore()
def center(self):
qr = self.frameGeometry()
cp = QDesktopWidget().availableGeometry().center()
qr.moveCenter(cp)
self.move(qr.topLeft())
if __name__ == '__main__':
app = QApplication(sys.argv)
ui = Gui()
sys.exit(app.exec_())
回答:
已解决:
Tensorflow只接受一维数组,而我的图像是三维数组,形状为[28, 28, 3]。所以我移除了RGB维度,并展平了二维数组。
然后我将结果导入到Tensorflow分类器中,但意识到我需要反转颜色,因此数组中的每个零应等于1,每个1应等于零。
这是代码:
im = mpimg.imread('dva-test.png')
im = im[:, :, 0]
im = im.ravel()
for j in range(len(im)):
if im[j] == 0:
im[j] = 1
elif im[j] == 1:
im[j] = 0
global prediction
prediction = cls.predict(np.array([im], dtype=float), as_iterable=False)