在 “model.fit_generator” 中,参数 “max_q_size” 的用途是什么?

我构建了一个简单的生成器,它会生成一个 tuple(inputs, targets),其中 inputstargets 列表中只包含单个项目。基本上,它是在逐个样本地爬取数据集。

我将这个生成器传递给:

  model.fit_generator(my_generator(),                      nb_epoch=10,                      samples_per_epoch=1,                      max_q_size=1  # 默认值为10                      )

我理解的是:

  • nb_epoch 是训练批次运行的次数
  • samples_per_epoch 是每个epoch训练的样本数量

但是 max_q_size 是做什么用的,为什么它的默认值是10?我以为使用生成器的目的是将数据集分成合理的块,为什么还需要额外的队列呢?


回答:

这只是定义了内部训练队列的最大尺寸,用于从生成器中“预缓存”您的样本。在生成队列时使用它。

def generator_queue(generator, max_q_size=10,                    wait_time=0.05, nb_worker=1):    '''从数据生成器构建一个线程队列。    用于 `fit_generator`, `evaluate_generator`, `predict_generator`。    '''    q = queue.Queue()    _stop = threading.Event()    def data_generator_task():        while not _stop.is_set():            try:                if q.qsize() < max_q_size:                    try:                        generator_output = next(generator)                    except ValueError:                        continue                    q.put(generator_output)                else:                    time.sleep(wait_time)            except Exception:                _stop.set()                raise    generator_threads = [threading.Thread(target=data_generator_task)                         for _ in range(nb_worker)]    for thread in generator_threads:        thread.daemon = True        thread.start()    return q, _stop

换句话说,您有一个线程直接从您的生成器中填充队列,直到达到给定的最大容量,而(例如)训练程序会消耗这些元素(有时需要等待完成)

 while samples_seen < samples_per_epoch:     generator_output = None     while not _stop.is_set():         if not data_gen_queue.empty():             generator_output = data_gen_queue.get()             break         else:             time.sleep(wait_time)

为什么默认值是10?没有特别的原因,就像大多数默认值一样——它只是看起来合理,但您也可以使用不同的值。

这样的构造表明,作者考虑了执行时间较长的数据生成器。例如,考虑在生成器调用中通过网络下载数据——那么预缓存一些接下来的批次,并并行下载后续批次以提高效率和对网络错误等的鲁棒性是有意义的。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注