我手动构建了一个数据生成器,每次调用时会生成一个包含[input, target]
的元组。我设置了生成器在每个epoch开始时对训练样本进行洗牌。然后我使用fit_generator
来调用我的generator
,但对该函数中的”shuffle”参数感到困惑:
fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
根据Keras的API文档:
shuffle:是否在每个epoch开始时对批次顺序进行洗牌。仅在使用keras.utils.Sequence实例时使用
我以为”shuffle”应该是generator
的工作。当我的自定义generator
决定每次迭代输出哪个批次时,如何能洗牌批次的顺序呢?
回答:
正如你引用的文档所说,shuffle参数只对实现了keras.utils.Sequence的生成器有意义。
如果你使用的是“简单”生成器(例如keras.preprocessing.image.ImageDataGenerator,或者你自己的非Sequence自定义生成器),那么该生成器实现了一种返回单个批次的方法(使用yield – 你可以在这个问题中了解更多)。因此,只有生成器本身控制返回哪个批次。
keras.utils.Sequence的引入是为了支持多进程处理:
Sequence提供了一种更安全的多进程处理方式。这种结构保证了网络在每个epoch中只会对每个样本训练一次,而使用生成器时则不然。
为此,你需要实现一个通过批次索引返回批次的方法(这允许多个工作者的同步):__getitem__(self, idx)
。如果你启用了shuffle参数,__getitem__
方法将以随机顺序被调用。
然而,你也可以将其设置为false,并通过实现on_epoch_end
方法来自己进行洗牌。