使用Tensorflow在MNIST数据集上训练一个带有隐藏层的全连接网络

我刚刚开始学习Tensorflow的机器学习,在完成MNIST初学者教程后,我想通过添加一个隐藏层来稍微提高那个简单模型的准确性。基本上,我决定直接复制Michael Nielsen关于神经网络和深度学习的书的第一章中的网络架构(见这里)。

Nielsen的代码对我来说运行得很好,但是使用以下Tensorflow代码时,我没有得到类似的结果。如果我没有弄错的话,它应该完全实现了Nielsen提出的模型:

from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tfmnist = input_data.read_data_sets("MNIST_data/", one_hot=True)def weight_variable(shape):    initial = tf.random_normal(shape)    return tf.Variable(initial)def bias_variable(shape):    initial = tf.random_normal(shape)    return tf.Variable(initial)x = tf.placeholder(tf.float32, [None, 784])#hidden layerW_fc1 = weight_variable([784, 30])b_fc1 = bias_variable([30])h_fc1 = tf.sigmoid(tf.matmul(x, W_fc1) + b_fc1)#output layerW_fc2 = weight_variable([30, 10])b_fc2 = bias_variable([10])y = tf.sigmoid(tf.matmul(h_fc1, W_fc2) + b_fc2)y_ = tf.placeholder(tf.float32, [None, 10])loss = tf.reduce_mean(tf.reduce_sum(tf.pow(y_ - y, 2), reduction_indices=[1])) #I also tried simply tf.nn.l2_loss(y_ - y)train_step = tf.train.GradientDescentOptimizer(3.0).minimize(loss)sess = tf.InteractiveSession()tf.global_variables_initializer().run()def get_accuracy():    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))    return sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})for i in range(30):    batch_xs, batch_ys = mnist.train.next_batch(10)    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})    print("Epoch {} accuracy: {:.2f}%".format(i+1, get_accuracy() * 100))

在训练30个周期后,我得到的准确率约为17%。使用Nielsen的代码,仅经过一个周期的训练,我就能得到91%的准确率。

显然我遗漏了什么。我尝试过提高准确率,通过更长时间的训练,我设法将其提高到约60%,但即使使用不同的后端代码,同样的网络也应该给出相似的结果。我还尝试调整超参数,但没有达到任何可比的结果。

你能在我代码中找到任何错误吗?


回答:

正如suharshs提到的,看起来你的问题是由于对“epoch”一词的误解造成的。虽然不一定是严格的,但一个epoch通常是对整个训练数据集的一次迭代。如果你再看一下Nielsen的代码,你会看到这在SGD方法中有所体现。一个epoch涉及遍历整个training_data,这些数据被分成小批次。你的每个epoch实际上是小批次的大小,只有10个样本。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注