我在学习机器学习的过程中,发现了一个使用MNIST数据集的Tensorflow CNN代码。这里有一段我想了解的代码。
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), reduction_indices=[1]))train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))sess.run(tf.global_variables_initializer())for i in range(1000): batch = mnist.train.next_batch(100) if i%100 == 0: train_accuracy = accuracy.eval(feed_dict={ x:batch[0], y_: batch[1], keep_prob: 1.0}) print("step %d, training accuracy %g"%(i, train_accuracy)) train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})print("test accuracy %g"%accuracy.eval(feed_dict={ x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
在这个代码中,我的疑问在于batch = mnist.train.next_batch(100)。经过查找,我了解到这表示这是小批量处理,并且从MNIST数据集中随机选择100个数据。现在我的问题是:
- 如果我想用完整批次测试这个代码,我应该怎么做?只需将mnist.train.next_batch(100)改为mnist.train.next_batch(55000)吗?
回答:
是的,获取55000个数据的批次将会在MNIST的所有数字上训练一个周期。
需要注意的是,这不是一个好主意:这很可能无法适应您的内存。您需要保存55000个数字的权重激活和梯度……您的Python很可能会崩溃!
通过在100个随机图像的批次上训练1000次,您将得到很好的结果,并且您的计算机会很开心!