批处理看起来比feed_dicts更简洁,所以我正在尝试理解Tensorflow中的批处理。
下面的代码块是否在批次中创建了32个相同的图像,然后将其馈送到队列中?
# 创建32个图像和32个标签的批次
image_batch, label_batch = tf.train.shuffle_batch(
[single_image, single_label],
batch_size=32,
num_threads=4,
capacity=50000,
min_after_dequeue=10000)
一些背景信息:我目前有一个包含大约5万行的单一文件。我使用tf.train.string_input_producer
和tf.decode_csv
从csv文件中读取行,但我对向tf.train.shuffle_batch
提供什么作为参数感到困惑,是单独的行还是包含从文件中读取的所有行的张量。
回答:
在您的问题中的代码片段中,张量single_image
和single_label
对应于一个图像及其相关的标签。从tf.train.shuffle_batch()
返回的张量image_batch
和label_batch
对应于32个可能不同的图像打包在一起,以及32个相关的标签。TensorFlow内部使用tf.RandomShuffleQueue
来打乱数据,并创建额外的线程来评估single_image
和single_label
,以便它们可以被添加到这个队列中。
tf.train.shuffle_batch()
函数根据您传递的参数具有不同的行为。例如,如果您传递enqueue_many=True
,那么tensors
参数中的张量将被解释为元素的批次,TensorFlow将在前导维度上将它们连接起来(因此每个张量在第0维度上必须具有相同的大小)。使用enqueue_many=True
,您可以将整个数据集传递给tf.train.shuffle_batch()
,或者您可以传递元素的批次(例如,使用tf.ReaderBase.read_up_to()
)。
* 我所说的“可能不同的”是因为您使用的是批处理函数的打乱版本,如果您的数据集相对于capacity
和min_after_dequeue
参数较小,那么在单个批次中看到同一示例的多个副本是可能的。