我最近开始使用tensorflow学习CNN,发现tfrecords在加速训练方面非常有帮助,但我正在与数据API作斗争。
解析后,我的数据集由(图像,标签)元组组成,这对于训练来说很好,但是我试图在另一个数据集中提取图像,以便调用keras.predict()。
我尝试了这个解决方案:
test_set = get_set_tfrecord(test_path, _parse_function, num_parallel_calls = 4)lab = []f = Truefor image, label in test_set.take(600): if f: img = tf.data.Dataset.from_tensors(image) f = False else: img = img.concatenate(tf.data.Dataset.from_tensors(image)) lab.append(label.numpy())
这个代码简单粗糙,不是很好,但它有效,除了为了执行连接(即堆叠)操作,它会将每个图像加载到RAM中。
正确的做法是什么?
回答:
你可以使用tf.data.Dataset
中的map
API。你可以编写以下代码。
result = test_set.map(lambda image, label: image)# 你可以迭代并检查最终收到的内容。# 我期望只有图像。for image in result.take(1): print(image)
我希望使用上述代码能解决你的问题,并且这个答案对你有帮助。