Tensorflow CNN图像增强管道

我在学习新的Tensorflow API时,遇到了一些困惑,不知道如何获取我的输入批次张量,以便使用例如tf.image进行操作和增强。

这是我当前的网络和管道:

trainX, testX, trainY, testY = read_data()# trainX [num_image, height, width, channels],这些是numpy数组#...train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))#...iterator = tf.data.Iterator.from_structure(train_dataset.output_types,                  train_dataset.output_shapes)features, labels = iterator.get_next()train_init_op = iterator.make_initializer(train_dataset)test_init_op = iterator.make_initializer(test_dataset)#...定义CNN架构...# 在训练循环中TrainLoop {   sess.run(train_init_op)  # 切换到训练数据   sess.run(train_step, ...) # 运行训练步骤   #...    sess.run(test_init_op)  # 切换到测试数据   test_loss = sess.run(loss, ...) # 在一个epoch后打印测试损失}

我使用Dataset API创建了两个数据集,以便在训练循环中计算并记录训练和测试损失。

在这个管道中,我应该在哪里操作和扭曲我的输入图像批次?我没有为trainX输入批次创建任何tf.placeholders,因此我无法使用tf.image来操作它们,因为例如tf.image.flip_up_down需要一个3-D或4-D张量。

  • 使用新API实现这个管道的自然方法是什么?
  • 在这个管道中,是否有模块或简单的方法可以增强用于训练的输入图像批次?

回答:

最近发布了一篇非常好的文章和一个演讲,它们比我在这里的回答更详细地介绍了API。这是一个简短的例子:

import tensorflow as tfimport numpy as npdef read_data():    n_train = 100    n_test = 50    height = 20    width = 30    channels = 3    trainX = (np.random.random(        size=(n_train, height, width, channels)) * 255).astype(np.uint8)    testX = (np.random.random(            size=(n_test, height, width, channels))*255).astype(np.uint8)    trainY = (np.random.random(size=(n_train,))*10).astype(np.int32)    testY = (np.random.random(size=(n_test,))*10).astype(np.int32)    return trainX, testX, trainY, testYtrainX, testX, trainY, testY = read_data()# trainX [num_image, height, width, channels],这些是numpy数组train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))test_dataset = tf.data.Dataset.from_tensor_slices((testX, testY))def map_single(x, y):    print('Map single:')    print('x shape: %s' % str(x.shape))    print('y shape: %s' % str(y.shape))    x = tf.image.per_image_standardization(x)    # 考虑:x = tf.image.random_flip_left_right(x)    return x, ydef map_batch(x, y):    print('Map batch:')    print('x shape: %s' % str(x.shape))    print('y shape: %s' % str(y.shape))    # 注意:这会将所有图像左右翻转。不知道这是不是你想要的    # 更新:看起来tf文档有误,你需要一个3D张量?    # return tf.image.flip_left_right(x), y    return x, ybatch_size = 32train_dataset = train_dataset.repeat().shuffle(100)train_dataset = train_dataset.map(map_single, num_parallel_calls=8)train_dataset = train_dataset.batch(batch_size)train_dataset = train_dataset.map(map_batch)train_dataset = train_dataset.prefetch(2)test_dataset = test_dataset.map(        map_single, num_parallel_calls=8).batch(batch_size).map(map_batch)test_dataset = test_dataset.prefetch(2)iterator = tf.data.Iterator.from_structure(train_dataset.output_types,                  train_dataset.output_shapes)features, labels = iterator.get_next()train_init_op = iterator.make_initializer(train_dataset)test_init_op = iterator.make_initializer(test_dataset)with tf.Session() as sess:    sess.run(train_init_op)    feat, lab = sess.run((features, labels))    print(feat.shape)    print(lab.shape)    sess.run(test_init_op)    feat, lab = sess.run((features, labels))    print(feat.shape)    print(lab.shape)    

一些注意事项:

  1. 这种方法依赖于能够将整个数据集加载到内存中。如果不能,考虑使用tf.data.Dataset.from_generator。如果你的shuffle缓冲区很大,这可能会导致shuffle时间变慢。我更喜欢的方法是将一些keys张量完全加载到内存中——它可能只是每个示例的索引——然后使用tf.py_func将该键值映射到数据值。这比转换到tfrecords稍低效,但通过prefetching可能不会影响性能。由于映射之前完成了洗牌,你只需要将shuffle_buffer键加载到内存中,而不是shuffle_buffer示例。
  2. 要增强你的数据集,使用tf.data.Dataset.map,根据你是否希望应用批次操作(在4D图像张量上工作)或元素操作(3D图像张量),在批次操作之前或之后使用。请注意,看起来tf.image.flip_left_right的文档已经过时,因为当我尝试使用4D张量时会出现错误。如果你想随机增强你的数据,使用tf.image.random_flip_left_right而不是tf.image.flip_left_right
  3. 如果你使用的是tf.estimator.Estimator(或者不介意将你的代码转换为使用它),那么查看tf.estimator.train_and_evaluate,以获取内置的在数据集之间切换的方法。
  4. 考虑使用shuffle/repeat方法来洗牌/重复你的数据集。关于效率,请参见这篇文章。特别是,repeat -> shuffle -> map -> batch -> batch-wise map -> prefetch 似乎是大多数应用的最佳操作顺序。

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注