我构建了一个简单的生成器,它会生成一个 tuple(inputs, targets)
,其中 inputs
和 targets
列表中只包含单个项目。基本上,它是在逐个样本地爬取数据集。
我将这个生成器传递给:
model.fit_generator(my_generator(), nb_epoch=10, samples_per_epoch=1, max_q_size=1 # 默认值为10 )
我理解的是:
nb_epoch
是训练批次运行的次数samples_per_epoch
是每个epoch训练的样本数量
但是 max_q_size
是做什么用的,为什么它的默认值是10?我以为使用生成器的目的是将数据集分成合理的块,为什么还需要额外的队列呢?
回答:
这只是定义了内部训练队列的最大尺寸,用于从生成器中“预缓存”您的样本。在生成队列时使用它。
def generator_queue(generator, max_q_size=10, wait_time=0.05, nb_worker=1): '''从数据生成器构建一个线程队列。 用于 `fit_generator`, `evaluate_generator`, `predict_generator`。 ''' q = queue.Queue() _stop = threading.Event() def data_generator_task(): while not _stop.is_set(): try: if q.qsize() < max_q_size: try: generator_output = next(generator) except ValueError: continue q.put(generator_output) else: time.sleep(wait_time) except Exception: _stop.set() raise generator_threads = [threading.Thread(target=data_generator_task) for _ in range(nb_worker)] for thread in generator_threads: thread.daemon = True thread.start() return q, _stop
换句话说,您有一个线程直接从您的生成器中填充队列,直到达到给定的最大容量,而(例如)训练程序会消耗这些元素(有时需要等待完成)
while samples_seen < samples_per_epoch: generator_output = None while not _stop.is_set(): if not data_gen_queue.empty(): generator_output = data_gen_queue.get() break else: time.sleep(wait_time)
为什么默认值是10?没有特别的原因,就像大多数默认值一样——它只是看起来合理,但您也可以使用不同的值。
这样的构造表明,作者考虑了执行时间较长的数据生成器。例如,考虑在生成器调用中通过网络下载数据——那么预缓存一些接下来的批次,并并行下载后续批次以提高效率和对网络错误等的鲁棒性是有意义的。