如何在Keras的datagen.flow中加入验证数据?

这是我之前帖子中遇到的问题的延伸。

我在Keras中应用以下代码进行数据增强(目前我不想使用model.fit_generator,所以我手动使用datagen.flow进行循环)。

datagen = ImageDataGenerator(    featurewise_center=False,    featurewise_std_normalization=False,    rotation_range=20,    width_shift_range=0.2,    height_shift_range=0.2,    horizontal_flip=True)# 计算特征标准化所需的量# (如果应用ZCA白化,则包括标准差、均值和主成分)datagen.fit(x_train)# model.fit_generator的替代方案for e in range(epochs):    print('Epoch', e)    batches = 0    for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):        model.fit(x_batch, y_batch)        batches += 1        if batches >= len(x_train) / 32:            # 我们需要手动中断循环,因为            # 生成器会无限循环            break

我想在我的model.fit循环中加入验证数据。例如,我想用类似model.fit(X_batch,y_batch, validation_data=(x_val, y_val))的代码替换model.fit(X_batch,y_batch),在for循环中使用。

我对如何在for循环中使用datagen.flow加入这个验证组件感到有些困惑。欢迎任何关于我应该如何进行的见解。


回答:

我假设你已经将数据分成了训练集和验证集。如果没有,你需要这样做以便使用下面的建议。

你可以使用验证数据创建第二个数据生成器,然后简单地同时迭代这个生成器和训练数据生成器。我还在下面的代码中添加了进一步的帮助作为注释。

这是你的代码,经过修改以实现这一点,但你可能还想再做一些调整:

# 与你的代码无变化tr_datagen = ImageDataGenerator(    featurewise_center=False,    featurewise_std_normalization=False,    rotation_range=20,    width_shift_range=0.2,    height_shift_range=0.2,    horizontal_flip=True)# 为验证数据创建新的生成器val_datagen = ImageDataGenerator()    # 对验证数据不进行增强# 计算特征标准化所需的量# (如果应用ZCA白化,则包括标准差、均值和主成分)tr_datagen.fit(x_train)    # 如果不进行标准化或白化,可以省略此步骤 val_datagen.fit(x_val)     # 如果不进行标准化或白化,可以省略此步骤# model.fit_generator的替代方案for e in range(epochs):    print('Epoch', e)    batches = 0    # 结合两个生成器,在Python 3中使用zip()    for (x_batch, y_batch), (val_x, val_y) in zip(                                 tr_datagen.flow(x_train, y_train, batch_size=32),                                 val_datagen.flow(x_val, y_val, batch_size=32)):        model.fit(x_batch, y_batch, validation_Data=(val_x, val_y))        batches += 1        if batches >= len(x_train) / 32:            # 我们需要手动中断循环,因为            # 生成器会无限循环            break

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注