如何将MongoDB与Keras中的Sequence类或生成器一起使用?

我打算在MongoDB中存储大约50万张图片,并使用这个数据集来训练一个基于Keras的神经网络。

我选择使用数据库而不是将图片单独存储在磁盘上,以提高数据加载速度。加载单独的图片进行训练大约需要1.5小时——这太长了。

总数据量约为1TB,显然无法全部装入RAM,因此最好的方法是使用模型类的fit_generator方法,或者使用keras.utils.Sequence

我更倾向于使用Sequence。正如文档中所说,

使用keras.utils.Sequence可以保证顺序,并且在使用use_multiprocessing=True时,每个epoch内每个输入只使用一次。

Sequence进行子类化(以及编写生成器),包装数据库连接,需要实现__len____get_item__方法。

__len__方法没有问题,但__get_item__方法让我感到困惑,因为作为作者,我需要负责保证顺序和单次使用。

主要的困惑在于MongoDB中的对象编号问题。

Sequence子类的__get_item__方法必须返回完整的批次,并且在给定相同的输入索引时,必须重复返回相同的批次。

数据库为集合中的文档分配唯一的标识符(在pymongo中为_idObjectID),我想用它们来按批次收集图片,但不知道该如何操作。

(max_object_id - min_object_id) / n_per_batch这种计算方式是不可能的,因为ObjectID的创建方式使得它们之间无法进行减法运算。

我想得到一些建议,关于如何将_ids分割成不重叠的区间,并对这些区间进行编号。

我有一个针对LMDB的类似问题的旧解决方案,是从Github上借用的,它创建了一个额外的pickle文件来存储键的数组。我不喜欢这种方法。


回答:

实验表明,简单的方法效果很好。我在__init__中使用projection={'_id': True}获取ObjectIDs列表,然后检索相应的文档。

代码在这里

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注