我打算在MongoDB中存储大约50万张图片,并使用这个数据集来训练一个基于Keras的神经网络。
我选择使用数据库而不是将图片单独存储在磁盘上,以提高数据加载速度。加载单独的图片进行训练大约需要1.5小时——这太长了。
总数据量约为1TB,显然无法全部装入RAM,因此最好的方法是使用模型类的fit_generator
方法,或者使用keras.utils.Sequence
。
我更倾向于使用Sequence
。正如文档中所说,
使用
keras.utils.Sequence
可以保证顺序,并且在使用use_multiprocessing=True
时,每个epoch内每个输入只使用一次。
对Sequence
进行子类化(以及编写生成器),包装数据库连接,需要实现__len__
和__get_item__
方法。
__len__
方法没有问题,但__get_item__
方法让我感到困惑,因为作为作者,我需要负责保证顺序和单次使用。
主要的困惑在于MongoDB中的对象编号问题。
Sequence
子类的__get_item__
方法必须返回完整的批次,并且在给定相同的输入索引时,必须重复返回相同的批次。
数据库为集合中的文档分配唯一的标识符(在pymongo中为_id
或ObjectID
),我想用它们来按批次收集图片,但不知道该如何操作。
(max_object_id - min_object_id) / n_per_batch
这种计算方式是不可能的,因为ObjectID
的创建方式使得它们之间无法进行减法运算。
我想得到一些建议,关于如何将_ids
分割成不重叠的区间,并对这些区间进行编号。
我有一个针对LMDB的类似问题的旧解决方案,是从Github上借用的,它创建了一个额外的pickle文件来存储键的数组。我不喜欢这种方法。
回答:
实验表明,简单的方法效果很好。我在__init__
中使用projection={'_id': True}
获取ObjectIDs列表,然后检索相应的文档。
代码在这里。