Keras中的fit_generator是否应该在每个epoch后重置生成器?

我试图使用fit_generator和自定义生成器来读取内存中无法容纳的数据。我有125万行数据需要训练,因此我让生成器每次yield 50,000行。fit_generator有25个steps_per_epoch,我以为这会在每个epoch中处理125万行数据。我添加了一个打印语句来查看处理的偏移量,发现它在进入第二个epoch的几步后就超过了最大值。该文件总共有175万条记录,一旦超过10步,它在create_feature_matrix调用中会出现索引错误(因为它没有引入任何行)。

def get_next_data_batch():    import gc    nrows = 50000    skiprows = 0    while True:        d = pd.read_csv(file_loc,skiprows=range(1,skiprows),nrows=nrows,index_col=0)        print(skiprows)        x,y = create_feature_matrix(d)        yield x,y        skiprows = skiprows + nrows        gc.collect()get_data = get_next_data_batch()... set up a Keras NN ...model.fit_generator(get_next_data_batch(), epochs=100,steps_per_epoch=25,verbose=1,workers=4,callbacks=callbacks_list)

我使用fit_generator的方式是否有误,或者我的自定义生成器是否需要做一些修改才能正常工作?


回答:

不是的 – fit_generator不会重置生成器,它只是继续调用它。为了实现你想要的行为,你可以尝试以下方法:

def get_next_data_batch(nb_of_calls_before_reset=25):    import gc    nrows = 50000    skiprows = 0    nb_calls = 0    while True:        d = pd.read_csv(file_loc,skiprows=range(1,skiprows),nrows=nrows,index_col=0)        print(skiprows)        x,y = create_feature_matrix(d)        yield x,y        nb_calls += 1        if nb_calls == nb_of_calls_before_reset:            skiprows = 0        else:            skiprows = skiprows + nrows        gc.collect()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注