背景
Tensorflow中典型的输入管道如下所示:
tf.train.string_input_producer(list_of_filenames) (创建文件名队列) | \|/ 固定长度读取器从文件中读取记录 | \|/ 读取的记录被解码和处理(例如,如果处理图像,则进行裁剪、翻转等) | \|/ tf.train.shuffle_batch(tensors,num_threads) (创建一个混洗队列并返回张量批次)
问题
问1)在函数tf.train.string_input_producer()中没有num_threads参数。这是否意味着只有一个线程专门用于从文件名队列中读取文件名?
问2)函数tf.train.shuffle_batch()中的num_threads参数的作用范围是什么?即,这里提到的线程数是否也用于读取、解码和处理文件,还是仅用于创建张量批次?
问3)有没有一种方法可以打印出哪个线程读取了特定文件的文件名或记录?即某种程度上记录每个线程的工作情况?
回答:
所有的数据加载操作都在tensorflow图中执行,你需要启动一个或多个线程来迭代读取器/入队操作。Tensorflow提供了一个QueueRunner类来完成这一任务。Coordinator类可以非常简单地管理这些线程。
https://www.tensorflow.org/programmers_guide/threading_and_queues
这是上述链接中的示例代码:
# 创建一个队列运行器,将并行运行4个线程来入队示例。qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)# 启动图。sess = tf.Session()# 创建一个协调器,启动队列运行器线程。coord = tf.train.Coordinator()enqueue_threads = qr.create_threads(sess, coord=coord, start=True)# 运行训练循环,使用协调器控制终止。for step in xrange(1000000): if coord.should_stop(): break sess.run(train_op)# 完成后,请求线程停止。coord.request_stop()# 并等待它们实际停止。coord.join(enqueue_threads)
如果你在图外(在你自己的代码中,而不是使用TF操作)加载/预处理样本,那么你不会使用QueueRunner,而是使用你自己的类来使用sess.run(enqueue_op, feed_dict={...})
命令在循环中入队数据。
答1:线程数通过qr.create_threads(sess, coord=coord, start=True)
处理。
答2:TF会话是线程安全的,每次调用tf.run(...)
时,都会看到从开始时的当前变量的一致快照。你的QueueRunner入队操作可以运行任意数量的线程。它们会以线程安全的方式排队。
答3:我自己没有使用过tf.train.string_input_producer
,但我认为你需要在图的后期请求一个张量来dequeued
数据,只需将该张量添加到你的sess.run([train_op, dequeue_op])
请求列表中即可。