在Keras的fit_generator中,”shuffle”参数有什么作用?

我手动构建了一个数据生成器,每次调用时会生成一个包含[input, target]的元组。我设置了生成器在每个epoch开始时对训练样本进行洗牌。然后我使用fit_generator来调用我的generator,但对该函数中的”shuffle”参数感到困惑:

fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

根据Keras的API文档:

shuffle:是否在每个epoch开始时对批次顺序进行洗牌。仅在使用keras.utils.Sequence实例时使用

我以为”shuffle”应该是generator的工作。当我的自定义generator决定每次迭代输出哪个批次时,如何能洗牌批次的顺序呢?


回答:

正如你引用的文档所说,shuffle参数只对实现了keras.utils.Sequence的生成器有意义。

如果你使用的是“简单”生成器(例如keras.preprocessing.image.ImageDataGenerator,或者你自己的非Sequence自定义生成器),那么该生成器实现了一种返回单个批次的方法(使用yield – 你可以在这个问题中了解更多)。因此,只有生成器本身控制返回哪个批次。

keras.utils.Sequence的引入是为了支持多进程处理:

Sequence提供了一种更安全的多进程处理方式。这种结构保证了网络在每个epoch中只会对每个样本训练一次,而使用生成器时则不然。

为此,你需要实现一个通过批次索引返回批次的方法(这允许多个工作者的同步):__getitem__(self, idx)。如果你启用了shuffle参数,__getitem__方法将以随机顺序被调用。

然而,你也可以将其设置为false,并通过实现on_epoch_end方法来自己进行洗牌。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注