我有一个数据集,是通过以下代码使用tf.data.Dataset
创建的:
dataset = Dataset.from_tensor_slices(corona_new)dataset = dataset.window(WINDOW_SIZE, 1, drop_remainder=True)dataset = dataset.flat_map(lambda x: x.batch(WINDOW_SIZE))dataset = dataset.map(lambda x: tf.transpose(x))for i in dataset: print(i.numpy()) break
当我运行它时,得到的输出如下(这是一个批次的示例):
[[ 0. 125. 111. 232. 164. 134. 235. 190.] [ 0. 14. 16. 7. 9. 7. 6. 8.] [ 0. 132. 199. 158. 148. 141. 179. 174.] [ 0. 0. 0. 2. 0. 2. 1. 2.] [ 0. 0. 0. 0. 3. 5. 0. 0.]]
如何取消它们的批处理?
回答:
找到了我的解决方案。
在TensorFlow 2.0中,您可以通过调用.unbatch()
函数来取消tf.data.Dataset
的批处理。
示例:dataset.unbatch()