Keras中的fit_generator是否应该在每个epoch后重置生成器?

我试图使用fit_generator和自定义生成器来读取内存中无法容纳的数据。我有125万行数据需要训练,因此我让生成器每次yield 50,000行。fit_generator有25个steps_per_epoch,我以为这会在每个epoch中处理125万行数据。我添加了一个打印语句来查看处理的偏移量,发现它在进入第二个epoch的几步后就超过了最大值。该文件总共有175万条记录,一旦超过10步,它在create_feature_matrix调用中会出现索引错误(因为它没有引入任何行)。

def get_next_data_batch():    import gc    nrows = 50000    skiprows = 0    while True:        d = pd.read_csv(file_loc,skiprows=range(1,skiprows),nrows=nrows,index_col=0)        print(skiprows)        x,y = create_feature_matrix(d)        yield x,y        skiprows = skiprows + nrows        gc.collect()get_data = get_next_data_batch()... set up a Keras NN ...model.fit_generator(get_next_data_batch(), epochs=100,steps_per_epoch=25,verbose=1,workers=4,callbacks=callbacks_list)

我使用fit_generator的方式是否有误,或者我的自定义生成器是否需要做一些修改才能正常工作?


回答:

不是的 – fit_generator不会重置生成器,它只是继续调用它。为了实现你想要的行为,你可以尝试以下方法:

def get_next_data_batch(nb_of_calls_before_reset=25):    import gc    nrows = 50000    skiprows = 0    nb_calls = 0    while True:        d = pd.read_csv(file_loc,skiprows=range(1,skiprows),nrows=nrows,index_col=0)        print(skiprows)        x,y = create_feature_matrix(d)        yield x,y        nb_calls += 1        if nb_calls == nb_of_calls_before_reset:            skiprows = 0        else:            skiprows = skiprows + nrows        gc.collect()

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注