Keras `ImageDataGenerator` 对图像和掩码的增强方式不同

我在使用TensorFlow后端的Keras训练语义分割模型。我采用了ImageDataGenerator来进行图像增强,包括旋转、翻转和位移。按照文档的指导,我创建了一个字典maskgen_args,并使用它作为参数来实例化两个ImageDataGenerator对象。

maskgen_args = dict(    rotation_range=90,    validation_split=VALIDATION_SPLIT)image_datagen = ImageDataGenerator(**maskgen_args)mask_datagen = ImageDataGenerator(**maskgen_args)

训练数据生成器的设置如下,通过将seed设置为相同的值,掩码将与图像匹配。

training_data_generator = zip(    image_datagen.flow_from_directory(        data_dir,        target_size=(512, 512),        color_mode='rgb',        batch_size=BATCH_SIZE,        class_mode=None,        save_format='jpeg',        seed=GENERATE_SEED,        subset='training'    ),    mask_datagen.flow_from_directory(        label_dir,        target_size=(512, 512),        color_mode='grayscale',        batch_size=BATCH_SIZE,        class_mode=None,        save_format='png',        seed=GENERATE_SEED,        subset='training'    ))

到目前为止,没有出现任何问题。但由于我需要对图像进行一些额外的预处理(例如归一化),而不对掩码进行预处理,我创建了另一个imagegen_args字典,并在实例化ImageDataGenerator时使用它作为参数。

maskgen_args = dict(    rotation_range=90,    validation_split=VALIDATION_SPLIT)imagegen_args = dict(    samplewise_center=True,    samplewise_std_normalization=True,    channel_shift_range=10,    brightness_range=(0.7, 1.3),    **maskgen_args)image_datagen = ImageDataGenerator(**imagegen_args)mask_datagen = ImageDataGenerator(**maskgen_args)

当我检查training_data_generator的输出时,出现了问题:似乎图像和掩码是分别生成的:它们确实有随机旋转,但旋转的角度不同,不像之前那样一致。这里有一个食物图像及其掩码的示例。

Inconsistency

我检查了image_datagenmask_datagenid,在两种情况下它们的id都是不同的。我想知道为什么在第一种情况下它们可以以相同的随机角度旋转图像和掩码,但在第二种情况下却不行?当我确实需要为image_datagen提供额外的参数时,我应该做些什么才能让它们像第一种情况那样表现?


回答:

当你设置

channel_shift_range=10,brightness_range=(0.7, 1.3)

这会修改此生成器的RNG,使图像的RNG和掩码的RNG不再同步。

我建议你在KP新API发布之前使用自定义的Sequence来完成这项任务。(参见 https://github.com/keras-team/governance/blob/master/rfcs/20190729-keras-preprocessing-redesign.md

关于自定义Sequence的示例,我在这里提供了一个例子:https://dref360.github.io/deterministic-da/

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注