我有大量的numpy数组,无法一次性装入RAM中。假设有数百万个这样的数组:
np.arange(10)
- 我想将它们分块保存到文件系统中的一个文件中。
- 我想从文件中读取这些数组,并使用
model.fit_generator
将它们输入到我的keras模型中。
我了解过dask,它可以处理无法装入内存的大数据,但未能实现我的目标。
回答:
使用pickle将文件写入磁盘:
pickle.dump((x, y), open(file, "wb"), protocol=pickle.HIGHEST_PROTOCOL)
然后创建测试和训练文件列表,并创建一个生成器:
def raw_generator(files): while 1: for file_num, file in enumerate(files): try: x, y = pickle.load(open(file, 'rb')) batches = int(np.ceil(len(y) / batch_size)) for i in range(0, batches): end = min(len(x), i * batch_size + batch_size) yield x[i * batch_size:end], y[i * batch_size:end] except EOFError: print("error" + file)train_gen = preprocessing.generator(training_files)test_gen = preprocessing.generator(test_files)
最后调用fit_generator:
history = model.fit_generator( generator=train_gen, steps_per_epoch= (len(training_files)*data_per_file)/batch_size, epochs=epochs validation_data=test_gen, validation_steps=(len(test_files)*data_per_file)/batch_size, use_multiprocessing=False, max_queue_size=10, workers=1, verbose=1)