我在学习新的Tensorflow API时,遇到了一些困惑,不知道如何获取我的输入批次张量,以便使用例如tf.image进行操作和增强。
这是我当前的网络和管道:
trainX, testX, trainY, testY = read_data()# trainX [num_image, height, width, channels],这些是numpy数组#...train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))#...iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)features, labels = iterator.get_next()train_init_op = iterator.make_initializer(train_dataset)test_init_op = iterator.make_initializer(test_dataset)#...定义CNN架构...# 在训练循环中TrainLoop { sess.run(train_init_op) # 切换到训练数据 sess.run(train_step, ...) # 运行训练步骤 #... sess.run(test_init_op) # 切换到测试数据 test_loss = sess.run(loss, ...) # 在一个epoch后打印测试损失}
我使用Dataset API创建了两个数据集,以便在训练循环中计算并记录训练和测试损失。
在这个管道中,我应该在哪里操作和扭曲我的输入图像批次?我没有为trainX输入批次创建任何tf.placeholders,因此我无法使用tf.image来操作它们,因为例如tf.image.flip_up_down
需要一个3-D或4-D张量。
- 使用新API实现这个管道的自然方法是什么?
- 在这个管道中,是否有模块或简单的方法可以增强用于训练的输入图像批次?
回答:
最近发布了一篇非常好的文章和一个演讲,它们比我在这里的回答更详细地介绍了API。这是一个简短的例子:
import tensorflow as tfimport numpy as npdef read_data(): n_train = 100 n_test = 50 height = 20 width = 30 channels = 3 trainX = (np.random.random( size=(n_train, height, width, channels)) * 255).astype(np.uint8) testX = (np.random.random( size=(n_test, height, width, channels))*255).astype(np.uint8) trainY = (np.random.random(size=(n_train,))*10).astype(np.int32) testY = (np.random.random(size=(n_test,))*10).astype(np.int32) return trainX, testX, trainY, testYtrainX, testX, trainY, testY = read_data()# trainX [num_image, height, width, channels],这些是numpy数组train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))def map_single(x, y): print('Map single:') print('x shape: %s' % str(x.shape)) print('y shape: %s' % str(y.shape)) x = tf.image.per_image_standardization(x) # 考虑:x = tf.image.random_flip_left_right(x) return x, ydef map_batch(x, y): print('Map batch:') print('x shape: %s' % str(x.shape)) print('y shape: %s' % str(y.shape)) # 注意:这会将所有图像左右翻转。不知道这是不是你想要的 # 更新:看起来tf文档有误,你需要一个3D张量? # return tf.image.flip_left_right(x), y return x, ybatch_size = 32train_dataset = train_dataset.repeat().shuffle(100)train_dataset = train_dataset.map(map_single, num_parallel_calls=8)train_dataset = train_dataset.batch(batch_size)train_dataset = train_dataset.map(map_batch)train_dataset = train_dataset.prefetch(2)test_dataset = test_dataset.map( map_single, num_parallel_calls=8).batch(batch_size).map(map_batch)test_dataset = test_dataset.prefetch(2)iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)features, labels = iterator.get_next()train_init_op = iterator.make_initializer(train_dataset)test_init_op = iterator.make_initializer(test_dataset)with tf.Session() as sess: sess.run(train_init_op) feat, lab = sess.run((features, labels)) print(feat.shape) print(lab.shape) sess.run(test_init_op) feat, lab = sess.run((features, labels)) print(feat.shape) print(lab.shape)
一些注意事项:
- 这种方法依赖于能够将整个数据集加载到内存中。如果不能,考虑使用
tf.data.Dataset.from_generator
。如果你的shuffle缓冲区很大,这可能会导致shuffle时间变慢。我更喜欢的方法是将一些keys
张量完全加载到内存中——它可能只是每个示例的索引——然后使用tf.py_func
将该键值映射到数据值。这比转换到tfrecords
稍低效,但通过prefetching
可能不会影响性能。由于映射之前完成了洗牌,你只需要将shuffle_buffer
键加载到内存中,而不是shuffle_buffer
示例。 - 要增强你的数据集,使用
tf.data.Dataset.map
,根据你是否希望应用批次操作(在4D图像张量上工作)或元素操作(3D图像张量),在批次操作之前或之后使用。请注意,看起来tf.image.flip_left_right
的文档已经过时,因为当我尝试使用4D张量时会出现错误。如果你想随机增强你的数据,使用tf.image.random_flip_left_right
而不是tf.image.flip_left_right
。 - 如果你使用的是
tf.estimator.Estimator
(或者不介意将你的代码转换为使用它),那么查看tf.estimator.train_and_evaluate
,以获取内置的在数据集之间切换的方法。 - 考虑使用
shuffle
/repeat
方法来洗牌/重复你的数据集。关于效率,请参见这篇文章。特别是,repeat -> shuffle -> map -> batch -> batch-wise map -> prefetch 似乎是大多数应用的最佳操作顺序。