Keras fit_generator 每次训练一个样本,而我从生成器中产出的样本却不止一个

我在使用Keras训练模型时,尝试了’fit’和’fit_generator’两个函数。我不明白为什么它们的性能差异如此大,可能是我做错了什么。这是第一次编写批量生成器代码。

在批量大小为10的情况下,我观察到使用函数时-

fit:训练速度较快(每轮大约3分钟),详细信息中的计数以批量大小的倍数增加(这里是10)
样本-80/7632 […………………………] – ETA: 4:31 – 损失: 2.2072 – 准确率: 0.4375

fit_generator:训练速度较慢(每轮10分钟),详细信息中的计数每次增加1(不等于批量大小)
样本-37/7632 […………………………] – ETA: 42:25 – 损失: 2.1845 – 准确率: 0.3676

如你所见,对于相同的数据集,fit_generator的预计完成时间(ETA)过高。而且fit_generator每次增加1,而fit是以10的倍数增加。

生成器:

def batch_generator(X ,y, batch_size=10):    from sklearn.utils import shuffle    batch_count = int(len(X) / batch_size)    extra = len(X) - (batch_count * batch_size)    while 1:        #shuffle X and y        X_train, y_train = shuffle(X,y)        #Yeild Batches        for i in range(1, batch_count):            batch_start = (i-1) * batch_size            batch_end = i * batch_size            X_batch = X_train[batch_start: batch_end]            y_batch = y_train[batch_start: batch_end]            yield X_batch, y_batch        #Yeild Remaining Data less than batch size        if(extra > 0):            batch_start = batch_count * batch_size            X_batch = X_train[batch_start: -1]            y_batch = y_train[batch_start: -1]            yield X_batch, y_batch

Fit函数:

model.fit_generator(batch_generator(X, y, 10),                    verbose = 1,                    samples_per_epoch = len(X),                    epochs = 20,                    validation_data = (X_test, y_test),                    callbacks = callbacks_list)

谁能解释这是为什么?


回答:

fit_generator不使用样本,它使用步骤,你使用的是旧版Keras API,带有samples_per_epoch参数,这是不正确的,会产生错误的结果。正确的fit_generator调用应该是这样的:

model.fit_generator(batch_generator(X, y, 10),                    verbose = 1,                    steps_per_epoch = int(len(X) / batch_size),                    epochs = 20,                    validation_data = (X_test, y_test),                    callbacks = callbacks_list)

steps_per_epoch控制在宣布一轮结束之前要使用多少步骤(调用生成器的次数)。它应该设置为总样本数除以批量大小。对于fit_generator,进度条中的索引指的是步骤(批次),而不是样本,因此你不能直接将它们与fit的进度条中的索引进行比较。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注