我有一个很长的数据集,需要分块训练。我读到可以多次调用model.fit,但使用model.train_on_batch会更好。这是真的吗?为什么?
回答:
与其多次使用model.fit,不如使用TensorFlow中的make_csv_dataset函数,并将数据集传递给fit命令。假设你的数据是CSV格式。这个函数的优点是按需加载数据,而不是将所有数据一次性加载到主内存中。
tf.data.experimental.make_csv_dataset( file_pattern, batch_size, label_name=None, select_columns=None, shuffle=True,)
这里的file pattern是一个单一字符串,即文件名或字符串模式,如果你想加载多个文件。参见文档
如果你有图像数据集,你可以使用一种称为flow_from_directory的方法。这也以类似的方式工作,只加载处理所需的图像。
# 这是对图像定义预处理的预处理步骤。train_datagen = ImageDataGenerator( rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)test_datagen = ImageDataGenerator(rescale=1./255)# 这是你创建迭代器的地方。train_generator = train_datagen.flow_from_directory( 'data/train', target_size=(150, 150), batch_size=32, class_mode='binary')validation_generator = test_datagen.flow_from_directory( 'data/validation', target_size=(150, 150), batch_size=32, class_mode='binary')model.fit( train_generator, steps_per_epoch=2000, epochs=50, validation_data=validation_generator, validation_steps=800)
参见文档