在 “model.fit_generator” 中,参数 “max_q_size” 的用途是什么?

我构建了一个简单的生成器,它会生成一个 tuple(inputs, targets),其中 inputstargets 列表中只包含单个项目。基本上,它是在逐个样本地爬取数据集。

我将这个生成器传递给:

  model.fit_generator(my_generator(),                      nb_epoch=10,                      samples_per_epoch=1,                      max_q_size=1  # 默认值为10                      )

我理解的是:

  • nb_epoch 是训练批次运行的次数
  • samples_per_epoch 是每个epoch训练的样本数量

但是 max_q_size 是做什么用的,为什么它的默认值是10?我以为使用生成器的目的是将数据集分成合理的块,为什么还需要额外的队列呢?


回答:

这只是定义了内部训练队列的最大尺寸,用于从生成器中“预缓存”您的样本。在生成队列时使用它。

def generator_queue(generator, max_q_size=10,                    wait_time=0.05, nb_worker=1):    '''从数据生成器构建一个线程队列。    用于 `fit_generator`, `evaluate_generator`, `predict_generator`。    '''    q = queue.Queue()    _stop = threading.Event()    def data_generator_task():        while not _stop.is_set():            try:                if q.qsize() < max_q_size:                    try:                        generator_output = next(generator)                    except ValueError:                        continue                    q.put(generator_output)                else:                    time.sleep(wait_time)            except Exception:                _stop.set()                raise    generator_threads = [threading.Thread(target=data_generator_task)                         for _ in range(nb_worker)]    for thread in generator_threads:        thread.daemon = True        thread.start()    return q, _stop

换句话说,您有一个线程直接从您的生成器中填充队列,直到达到给定的最大容量,而(例如)训练程序会消耗这些元素(有时需要等待完成)

 while samples_seen < samples_per_epoch:     generator_output = None     while not _stop.is_set():         if not data_gen_queue.empty():             generator_output = data_gen_queue.get()             break         else:             time.sleep(wait_time)

为什么默认值是10?没有特别的原因,就像大多数默认值一样——它只是看起来合理,但您也可以使用不同的值。

这样的构造表明,作者考虑了执行时间较长的数据生成器。例如,考虑在生成器调用中通过网络下载数据——那么预缓存一些接下来的批次,并并行下载后续批次以提高效率和对网络错误等的鲁棒性是有意义的。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注