简短版本
假设有一个生成器采样,例如输出 3
个输入和 1
个标签,如何定义我的 Tensorflow 数据集管道以获得 K * 3
个输入和 K * 1
个标签的批次?
详细版本
背景
我正在使用三元组网络,并希望调整我的当前输入管道以使用 Tensorflow 数据集。
在我的案例中,一个批次包含 N
个元素(例如图片)和 N // 3
个标签(假设 N % 3 == 0
),每个标签应用于连续的 3 个输入,例如:
labels = [compute_label(inputs[3*i], inputs[3*i+1], inputs[3*i+2]) for i in range(N // 3)]
其中 compute_label(*args)
是一个简单的函数,可以使用 Tensorflow 操作或基本 Python 实现。
为了使事情更加复杂,输入元素必须以三连的方式采样(例如,我们希望 inputs[3*i]
与 inputs[3*i+1]
相似,而与 inputs[3*i+2]
不相似):
for i in range(N // 3): inputs[3*i], inputs[3*i+1], inputs[3*i+2] = sample_triplet(i)
问题
针对我的具体情况重述简短问题:
给定这两个函数 sample_triplet()
和 compute_label()
,我如何使用 Tensorflow 数据集构建输入管道,以构建包含 N
个输入和 N // 3
个标签的批次?
我尝试了多种 tf.data.Dataset.from_generator()
和 tf.data.Dataset.flat_map()
的组合,但无法找到一种方法既能将批次输入从 N // 3
个三元组展平为 N
个样本,又能仅输出 N // 3
个批次标签。
我找到的一个解决方案是“作弊”,在 tf.data.Dataset.from_generator()
中计算我的标签,并将每个标签重复 3 次,以便能够在三元组输入 + 标签上使用 tf.data.Dataset.flat_map()
。作为批次后处理步骤,我随后将 N
个重复的标签“压缩”回 N // 3
个。
当前解决方案的示例
import tensorflow as tfimport numpy as npdef sample_triplet(): # 采样我们的元素,这里作为 [class, random_val] 元素: anchor_class = puller_class = pusher_class = np.random.randint(0, 10) while pusher_class == anchor_class: # 我们希望推动者属于不同的类别 pusher_class = np.random.randint(0, 10) anchor = np.array([anchor_class, np.random.randint(0, 5)]) puller = np.array([puller_class, np.random.randint(0, 5)]) pusher = np.array([pusher_class, np.random.randint(0, 5)]) # 堆叠三元组,以便之后作为批次进行 flat_map: triplet_inputs = np.stack((anchor, puller, pusher), axis=0) # 同时返回类别以便之后计算标签: triplet_classes = np.stack((anchor_class, puller_class, pusher_class), axis=0) return triplet_inputs, triplet_classesdef compute_labels(triplet_classes): # 计算标签,例如锚点与推动者类别之间的距离: label = np.abs(triplet_classes[0] - triplet_classes[2]) return labeldef triplet_generator(): while True: triplet = sample_triplet() # 当前解决方案:在这里也计算标签, # 将其堆叠 3 次以便 flat_map 工作, # 然后删除重复项: triplet_inputs = triplet[0] triplet_label = compute_labels(triplet[1]) yield triplet_inputs, np.stack((triplet_label, triplet_label, triplet_label), axis=0)def squeeze_triplet_labels(*batch): # 删除重复的标签, # 从批次 (N 个输入, N 个标签) 变为 (N 个输入, N // 3 个标签) squeezed_labels = batch[-1][::3] new_batch = (*batch[:-1], squeezed_labels) return new_batchbatch_size = 30assert(batch_size % 3 == 0)sess = tf.InteractiveSession()train_dataset = (tf.data.Dataset .from_generator(triplet_generator, (tf.int32, tf.float32), ([3, 2], [3])) .flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x)) .batch(batch_size))next_training_batch = train_dataset.make_one_shot_iterator().get_next()next_proper_training_batch = squeeze_triplet_labels(*next_training_batch)batch = sess.run(next_proper_training_batch)print("输入形状: {} ; 标签形状: {}".format(batch[0].shape, batch[1].shape))# >> 输入形状: (30, 2) ; 标签形状: (10,)
回答:
一个简单的解决方案可以是创建两个数据集对象,一个用于标签,一个用于数据,然后按每组 3 个批处理数据,并使用 tf.data.interleave
将两个数据集重新组合在一起,产生你想要的结果。
如果这不容易做到,那么你可以尝试以下将一个元素映射到多个元素的过程。你需要创建一个包含 3 个元素的批次(带有 3 个标签),然后在映射函数中将其拆分为 3 组数据,每组对应你接收到的一个标签。执行此操作的配方可以在以下 Stack Overflow 问题中找到,尽管它比第一个建议稍微复杂一些: