如何在Tensorflow中进行批处理迭代?

我试图在我的数据上重用PTB语言模型,但缺乏对Tensorflow的了解,无法理解它如何处理训练数据的批处理迭代。我理解的训练期间的批处理迭代方式如下:

while epoch <= maxepoch do  for minibatch in data_iterator() do    model.forward(minibatch)    (...)  endend

这已经不能再简单了,对吗?许多其他框架中也有类似的做法,但在Tensorflow中却没有 🙂 以下是来自官方PTB语言模型教程的minibatch函数示例:

def ptb_producer(raw_data, batch_size, num_steps, name=None):    with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):        raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)        data_len = tf.size(raw_data)        batch_len = data_len // batch_size        data = tf.reshape(raw_data[0 : batch_size * batch_len],                                            [batch_size, batch_len])        epoch_size = (batch_len - 1) // num_steps        assertion = tf.assert_positive(                epoch_size,                message="epoch_size == 0, decrease batch_size or num_steps")        with tf.control_dependencies([assertion]):            epoch_size = tf.identity(epoch_size, name="epoch_size")        i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()        x = tf.strided_slice(data, [0, i * num_steps], [batch_size, (i + 1) * num_steps])        x.set_shape([batch_size, num_steps])        y = tf.strided_slice(data, [0, i * num_steps + 1], [batch_size, (i + 1) * num_steps + 1])        y.set_shape([batch_size, num_steps])        return x, y

这个函数在被调用时返回x输入和y目标。我在这里没有看到Python迭代器的迹象,但有对tf.strided_slice的调用,它使用了由tf.train.range_input_producer生成的i索引,因此这应该模拟了一个滑动窗口来遍历数据。然而,这个函数在训练前只被调用一次,那么它如何遍历我的数据呢?这一点不清楚。有人能解释这种“魔法”和完全不透明的Tensorflow机制吗?


回答:

这个“魔法”隐藏在调用tf.train.range_input_producer的行中:

i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()

… 这创建了一个操作,它从队列中弹出值,该队列持有0..epoch_size-1的整数。换句话说,它遍历0..epoch_size-1的范围。


是的,这似乎违反直觉。所以这里有一个在tensorflow中使用队列的简单可运行示例:

index = tf.train.range_input_producer(10, shuffle=False).dequeue()with tf.Session() as sess:  coord = tf.train.Coordinator()  threads = tf.train.start_queue_runners(coord=coord)  for i in range(15):    print(sess.run(index))  coord.request_stop()  coord.join(threads)

运行后,您应该会看到从09的值,然后再看到5个从04的值。请注意,sess.run评估相同的张量index,但每次都得到不同的值。可以添加更多依赖于index的操作,它们将使用index的新值进行评估。

还请注意,队列在另一个线程中运行,因此为了使用tf.train.range_input_producer,必须启动一个Coordinator并生成多个线程(并在最后停止它们)。如果您尝试在没有Coordinator的情况下运行相同的示例,sess.run(index)将会阻塞脚本执行。

您可以尝试这个示例,例如,设置shuffle=True等。


回到PTB生产者代码片段:

i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()x = tf.strided_slice(data, [0, i*num_steps], [batch_size, (i+1)*num_steps])x.set_shape([batch_size, num_steps])y = tf.strided_slice(data, [0, i*num_steps+1], [batch_size, (i+1)*num_steps+1])y.set_shape([batch_size, num_steps])

现在应该清楚,尽管xy被定义为简单的张量,但它们实际上是data切片的迭代器。所有的线程工作都由tf.train.Supervisor处理。因此,调用依赖于xy的优化操作将自动获取新的批次。


建议阅读:

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注