我创建了一个简单的HDF5读取器类,以避免将整个数据集加载到内存中。我使用了序列类来实现这一点,但我不能确定on_epoch_end()函数是否会正确触发。
我在其中放置了一个简单的打印语句,但它从未出现!所以我认为我的代码中可能有问题:
class HDF5Generator(tf.keras.utils.Sequence): def __init__(self, hdf5_file, shuffle=True): print("GENERATED") self.hdf5 = h5py.File(hdf5_file, 'r') self.shuffle = shuffle self.indices = list(range(0, len(self.hdf5["samples"]))) random.Random().shuffle(self.indices) def __len__(self): return len(self.hdf5["samples"]) def __getitem__(self, idx): return self.hdf5["samples"][self.indices[idx]], self.hdf5["labels"][self.indices[idx]] def on_epoch_end(self): print("RE-SHUFFLE") random.Random().shuffle(self.indices)
我这样使用它:
d = tf.data.Dataset.from_generator(HD5FGenerator, args=[dataset], output_signature=(...))d = d.batch(batch_size).prefetch(tf.data.AUTOTUNE).cache()...model.fit(d, epochs=epochs)
在控制台中显示了epoch计数器、进度条和字符串”GENERATED”,但从未显示”RE-SHUFFLE”
我错过了什么?
回答:
由于似乎是一个TensorFlow的bug,我找到了一个触发生成器on_epoch_end()
的解决方法。
class CallbackOnEpochEnd(Callback): def __init__(self, generator): super(CallbackOnEpochEnd, self).__init__() self.generator = generator def on_epoch_end(self, epoch, logs=None): self.generator.on_epoch_end()[...]generator = HDF5Generator()d = tf.data.Dataset.from_generator(lambda: generator, output_signature=(tf.TensorSpec(shape=(5,20)), tf.TensorSpec(shape=(1,))))[...]on_epoch_end_callback = CallbackOnEpochEnd(generator)[...]model.fit(d, epochs=5, callbacks=[on_epoch_end_callback])
使用这种方法后,每个epoch结束后控制台中都会显示”RE-SHUFFLE”!