在使用Keras的机器学习教程中,训练机器学习模型的代码通常是一行简单的代码。
model.fit(X_train, Y_train, nb_epoch=5, batch_size = 128, verbose=1, validation_split=0.1)
当训练数据X_train
和Y_train
较小时,这看起来很容易。X_train
和Y_train
是numpy的多维数组。在实际情况中,训练数据可能会达到几千兆字节,这可能太大而无法装入计算机的RAM中。
当训练数据过于庞大时,如何将数据传递给model.fit()
呢?
回答:
在Keras中有一个简单的解决方案。你可以简单地使用Python生成器,让你的数据延迟加载。如果你有图像数据,你还可以使用ImageDataGenerator。
def generate_data(x, y, batch_size): while True: batch = [] for b in range(batch_size): batch.append(myDataSlice) yield np.array(batch )model.fit_generator(generator=generate_data(x, y, batch_size),steps_per_epoch=num_batches, validation_data=list_batch_generator(x_val, y_val, batch_size), validation_steps=num_batches_test)