我在使用Chainer进行手写数字识别。我已经在MNIST数据库上训练了一个模型。然而,由于某种原因,我无法对单个示例进行分类。我猜想可能是示例的格式选择不当,但我已经尝试了多种方法,仍然无法解决这个问题。如果这是一个显而易见的问题,请原谅,我在Python方面经验不足。
代码:
import chainerimport chainer.functions as Fimport chainer.links as Limport osimport sysfrom chainer.training import extensionsimport numpy as npfrom tkinter.filedialog import askopenfilenamefrom PIL import Imagefrom chainer import serializersfrom chainer.dataset import concat_examplesfrom chainer.backends import cudafrom chainer import Function, gradient_check, report, training, utils, Variablefrom chainer import datasets, iterators, optimizers, serializersfrom chainer import Link, Chain, ChainListCONST_RESUME = ''class Network(chainer.Chain): def __init__(self, n_units, n_out): super(Network, self).__init__() with self.init_scope(): self.l1 = L.Linear(None, n_units) self.l2 = L.Linear(None, n_units) self.l3 = L.Linear(None, n_out) def __call__(self, x): h1 = F.sigmoid(self.l1(x)) h2 = F.sigmoid(self.l2(h1)) return self.l3(h2)def query_yes_no(question, default="no"): """Ask a yes/no question via raw_input() and return their answer. "question" is a string that is presented to the user. "default" is the presumed answer if the user just hits <Enter>. It must be "yes" (the default), "no" or None (meaning an answer is required of the user). The "answer" return value is True for "yes" or False for "no". """ valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False} if default is None: prompt = " [y/n] " elif default == "yes": prompt = " [Y/n] " elif default == "no": prompt = " [y/N] " else: raise ValueError("invalid default answer: '%s'" % default) while True: sys.stdout.write(question + prompt) choice = input().lower() if default is not None and choice == '': return valid[default] elif choice in valid: return valid[choice] else: sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n") return np.argmax(y.data)def main(): #file_list = [None]*10 #for i in range(10): # file_list[i] = open('data{}.txt'.format(i), 'rb') print('MNIST digit recognition.') usr_in : str model = L.Classifier(Network(100, 10)) chainer.backends.cuda.get_device_from_id(0).use() model.to_gpu() usr_in = input('Input (t) to train, (l) to load.') while len(usr_in) != 1 and (usr_in[0] != 'l' or usr_in[0] != 't'): print('Invalid input.') usr_in = input() if usr_in[0] == 't': optimizer = chainer.optimizers.Adam() optimizer.setup(model) train, test = chainer.datasets.get_mnist() train_iter = chainer.iterators.SerialIterator(train, batch_size=100, shuffle=True) test_iter = chainer.iterators.SerialIterator(test, 100, repeat=False, shuffle=False) updater = training.updaters.StandardUpdater(train_iter, optimizer, device=0) trainer = training.Trainer(updater, (10, 'epoch'), out='out.txt') trainer.extend(extensions.Evaluator(test_iter, model, device=0)) trainer.extend(extensions.dump_graph('main/loss')) frequency = 1 trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch')) trainer.extend(extensions.LogReport()) if extensions.PlotReport.available(): trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png')) trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png')) trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'])) trainer.extend(extensions.ProgressBar()) if CONST_RESUME: chainer.serializers.load_npz(CONST_RESUME, trainer) trainer.run() ans = query_yes_no('Would you like to save this network?') if ans: usr_in = input('Input filename: ') serializers.save_npz('{}.{}'.format(usr_in, 'npz'), model) elif usr_in[0] == 'l': filename = askopenfilename(initialdir=os.getcwd(), title='Choose a file') serializers.load_npz(filename, model) else: return while True: ans = query_yes_no('Would you like to evaluate an image with the current network?') if ans: filename = askopenfilename(initialdir=os.getcwd(), title='Choose a file') file = Image.open(filename) bw_file = file.convert('L') size = 28, 28 bw_file.thumbnail(size, Image.ANTIALIAS) pix = bw_file.load() x = np.empty([28 * 28]) for i in range(28): for j in range(28): x[i * 28 + j] = pix[i, j] #gpu_id = 0 #batch = (, gpu_id) x = (x.astype(np.float32))[None, ...] y = model(x) print('predicted_label:', y.argmax(axis=1)[0]) else: returnmain()
回答:
你能展示一下错误消息吗?否则很难猜测错误的原因。
但我猜测输入的形状可能不同。当你使用Chainer内置函数获取数据集时,train, test = chainer.datasets.get_mnist()
这些数据集图像的形状是(minibatch, channel, height, width)
。但看起来你构建的输入x
的形状是(minibatch, height * width = 28*28=784)
,这是不同的形状?
你也可以参考一些Chainer的教程,