在Keras中合并两个DirectoryIterator

假设我有类似下面的代码:

image_data_generator = ImageDataGenerator(rescale=1./255)train_generator = image_data_generator.flow_from_directory(  'my_directory',  target_size=(28, 28),  batch_size=32,  class_mode='categorical')

那么我的train_generator将被my_directory中的数据填充,该目录包含两个子文件夹,将数据分为01两个类别。

假设我还有另一个目录that_directory,也包含01两个类别的数据。我希望用这些额外的数据增强我的train_generator

运行train_generator = image_data_generator.flow_from_directory('that_directory', ...)会移除my_directory中的先前数据。

有没有办法在不改变文件夹结构本身的情况下,将两组数据增强或追加到一个生成器或类似DirectoryIterator的对象中?


回答:

只需在另一个生成器中组合这些生成器,可选择使用不同的增强配置:

idg1 = ImageDataGenerator(**idg1_configs)idg2 = ImageDataGenerator(**idg2_configs)g1 = idg1.flow_from_directory('idg1_dir',...)g2 = idg2.flow_from_directory('idg2_dir',...)def combine_gen(*gens):    while True:        for g in gens:            yield next(g)# ...model.fit_generator(combine_gen(g1, g2), steps_per_epoch=len(g1)+len(g2), ...)

这将交替从g1g2生成批次数据。

请注意,有人可能会建议使用itertools.chain,但在这里不能使用,因为ImageDataGenerators生成器是永无止境的,不断生成数据批次。这对于传递给fit_generator方法的生成器是预期的。从Keras文档中:

…生成器预期会无限循环其数据。当模型看到steps_per_epoch批次时,一个epoch就结束了。

如果没有设置steps_per_epoch,它将默认设置为len(generator),其中generator是你传递给fit_generator方法的生成器。ImageDataGenerator生成器可以提供它们的长度,因此你不需要手动设置steps_per_epoch参数。如果你想对上面的组合生成器使用相同的方法,可以使用以下解决方案:

class CombinedGen():    def __init__(self, *gens):        self.gens = gens    def generate(self):        while True:            for g in self.gens:                yield next(g)    def __len__(self):        return sum([len(g) for g in self.gens])# 使用:cg = CombinedGen(g1, g2)model.fit_generator(cg.generate(), ...) # 不需要设置`steps_per_epoch`

你还可以向CombinedGen类添加__next__和/或__iter__方法,如果你有兴趣直接迭代这个类的对象(而不是迭代cg.generate())。

Related Posts

在使用k近邻算法时,有没有办法获取被使用的“邻居”?

我想找到一种方法来确定在我的knn算法中实际使用了哪些…

Theano在Google Colab上无法启用GPU支持

我在尝试使用Theano库训练一个模型。由于我的电脑内…

准确性评分似乎有误

这里是代码: from sklearn.metrics…

Keras Functional API: “错误检查输入时:期望input_1具有4个维度,但得到形状为(X, Y)的数组”

我在尝试使用Keras的fit_generator来训…

如何使用sklearn.datasets.make_classification在指定范围内生成合成数据?

我想为分类问题创建合成数据。我使用了sklearn.d…

如何处理预测时不在训练集中的标签

已关闭。 此问题与编程或软件开发无关。目前不接受回答。…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注