我试图使用fit_generator
和自定义生成器来读取内存中无法容纳的数据。我有125万行数据需要训练,因此我让生成器每次yield 50,000行。fit_generator
有25个steps_per_epoch
,我以为这会在每个epoch中处理125万行数据。我添加了一个打印语句来查看处理的偏移量,发现它在进入第二个epoch的几步后就超过了最大值。该文件总共有175万条记录,一旦超过10步,它在create_feature_matrix
调用中会出现索引错误(因为它没有引入任何行)。
def get_next_data_batch(): import gc nrows = 50000 skiprows = 0 while True: d = pd.read_csv(file_loc,skiprows=range(1,skiprows),nrows=nrows,index_col=0) print(skiprows) x,y = create_feature_matrix(d) yield x,y skiprows = skiprows + nrows gc.collect()get_data = get_next_data_batch()... set up a Keras NN ...model.fit_generator(get_next_data_batch(), epochs=100,steps_per_epoch=25,verbose=1,workers=4,callbacks=callbacks_list)
我使用fit_generator
的方式是否有误,或者我的自定义生成器是否需要做一些修改才能正常工作?
回答:
不是的 – fit_generator
不会重置生成器,它只是继续调用它。为了实现你想要的行为,你可以尝试以下方法:
def get_next_data_batch(nb_of_calls_before_reset=25): import gc nrows = 50000 skiprows = 0 nb_calls = 0 while True: d = pd.read_csv(file_loc,skiprows=range(1,skiprows),nrows=nrows,index_col=0) print(skiprows) x,y = create_feature_matrix(d) yield x,y nb_calls += 1 if nb_calls == nb_of_calls_before_reset: skiprows = 0 else: skiprows = skiprows + nrows gc.collect()