Chainer: 无法分类,训练模型(x)抛出错误

我在使用Chainer进行手写数字识别。我已经在MNIST数据库上训练了一个模型。然而,由于某种原因,我无法对单个示例进行分类。我猜想可能是示例的格式选择不当,但我已经尝试了多种方法,仍然无法解决这个问题。如果这是一个显而易见的问题,请原谅,我在Python方面经验不足。

这是错误的显示

代码:

import chainerimport chainer.functions as Fimport chainer.links as Limport osimport sysfrom chainer.training import extensionsimport numpy as npfrom tkinter.filedialog import askopenfilenamefrom PIL import Imagefrom chainer import serializersfrom chainer.dataset import concat_examplesfrom chainer.backends import cudafrom chainer import Function, gradient_check, report, training, utils, Variablefrom chainer import datasets, iterators, optimizers, serializersfrom chainer import Link, Chain, ChainListCONST_RESUME = ''class Network(chainer.Chain):    def __init__(self, n_units, n_out):        super(Network, self).__init__()        with self.init_scope():            self.l1 = L.Linear(None, n_units)            self.l2 = L.Linear(None, n_units)            self.l3 = L.Linear(None, n_out)    def __call__(self, x):        h1 = F.sigmoid(self.l1(x))        h2 = F.sigmoid(self.l2(h1))        return self.l3(h2)def query_yes_no(question, default="no"):    """Ask a yes/no question via raw_input() and return their answer.    "question" is a string that is presented to the user.    "default" is the presumed answer if the user just hits <Enter>.        It must be "yes" (the default), "no" or None (meaning        an answer is required of the user).    The "answer" return value is True for "yes" or False for "no".    """    valid = {"yes": True, "y": True, "ye": True,             "no": False, "n": False}    if default is None:        prompt = " [y/n] "    elif default == "yes":        prompt = " [Y/n] "    elif default == "no":        prompt = " [y/N] "    else:        raise ValueError("invalid default answer: '%s'" % default)    while True:        sys.stdout.write(question + prompt)        choice = input().lower()        if default is not None and choice == '':            return valid[default]        elif choice in valid:            return valid[choice]        else:            sys.stdout.write("Please respond with 'yes' or 'no' "                             "(or 'y' or 'n').\n")    return np.argmax(y.data)def main():    #file_list = [None]*10    #for i in range(10):    #    file_list[i] = open('data{}.txt'.format(i), 'rb')    print('MNIST digit recognition.')    usr_in : str    model = L.Classifier(Network(100, 10))    chainer.backends.cuda.get_device_from_id(0).use()    model.to_gpu()    usr_in = input('Input (t) to train, (l) to load.')    while len(usr_in) != 1 and (usr_in[0] != 'l' or usr_in[0] != 't'):        print('Invalid input.')        usr_in = input()    if usr_in[0] == 't':        optimizer = chainer.optimizers.Adam()        optimizer.setup(model)        train, test = chainer.datasets.get_mnist()        train_iter = chainer.iterators.SerialIterator(train, batch_size=100, shuffle=True)        test_iter = chainer.iterators.SerialIterator(test, 100, repeat=False, shuffle=False)        updater = training.updaters.StandardUpdater(train_iter, optimizer, device=0)        trainer = training.Trainer(updater, (10, 'epoch'), out='out.txt')        trainer.extend(extensions.Evaluator(test_iter, model, device=0))        trainer.extend(extensions.dump_graph('main/loss'))        frequency = 1        trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))        trainer.extend(extensions.LogReport())        if extensions.PlotReport.available():            trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png'))            trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png'))        trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))        trainer.extend(extensions.ProgressBar())        if CONST_RESUME:            chainer.serializers.load_npz(CONST_RESUME, trainer)        trainer.run()        ans = query_yes_no('Would you like to save this network?')        if ans:            usr_in = input('Input filename: ')            serializers.save_npz('{}.{}'.format(usr_in, 'npz'), model)    elif usr_in[0] == 'l':        filename = askopenfilename(initialdir=os.getcwd(), title='Choose a file')        serializers.load_npz(filename, model)    else:        return    while True:        ans = query_yes_no('Would you like to evaluate an image with the current network?')        if ans:            filename = askopenfilename(initialdir=os.getcwd(), title='Choose a file')            file = Image.open(filename)            bw_file = file.convert('L')            size = 28, 28            bw_file.thumbnail(size, Image.ANTIALIAS)            pix = bw_file.load()            x = np.empty([28 * 28])            for i in range(28):                for j in range(28):                    x[i * 28 + j] = pix[i, j]            #gpu_id = 0            #batch = (, gpu_id)            x = (x.astype(np.float32))[None, ...]            y = model(x)            print('predicted_label:', y.argmax(axis=1)[0])        else:            returnmain()

回答:

你能展示一下错误消息吗?否则很难猜测错误的原因。

但我猜测输入的形状可能不同。当你使用Chainer内置函数获取数据集时,train, test = chainer.datasets.get_mnist()这些数据集图像的形状是(minibatch, channel, height, width)。但看起来你构建的输入x的形状是(minibatch, height * width = 28*28=784),这是不同的形状?

你也可以参考一些Chainer的教程,

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注