I'm new to machine learning and trying to solve a problem in my code. My training dataset is stored in a tf.record file because it is too large to fit in memory at once, so I use an iterator to feed the training set. The problem is that the validation set is also too large to fit in memory (at least for my laptop with only 2GB of video memory), so it is saved in tf.record format as well, and I don't think I can use the same iterator trick for it. What approach should I use instead?
Code
import tensorflow as tf

# reading training and validation datasets
def read_tfrecord(example):
    tfrecord_format = (
        {
            "x": tf.io.FixedLenFeature([], tf.string),
            "y": tf.io.FixedLenFeature([], tf.string),
        }
    )
    example = tf.io.parse_single_example(example, tfrecord_format)
    x = tf.io.parse_tensor(example['x'], out_type=tf.float32)
    y = tf.io.parse_tensor(example['y'], out_type=tf.double)
    return x, y

filename = "train.tfrecord"
training_dataset = tf.data.TFRecordDataset(filename).map(read_tfrecord)
iterator = training_dataset.repeat().prefetch(10).as_numpy_iterator()

filename = "validation.tfrecord"
validation_dataset = tf.data.TFRecordDataset(filename).map(read_tfrecord)
val_iterator = validation_dataset.repeat().prefetch(10).as_numpy_iterator()
Then I call the fit method like this:
model.fit(iterator, validation_data=(val_iterator), epochs=35, verbose=1)
But the program never gets past the first epoch; it hangs and never finishes.
Answer:
I found a solution using a generator; I'll post the code. The original call hung because repeat() makes both datasets infinite, and without steps_per_epoch and validation_steps, fit() never knows where an epoch ends.
import tensorflow as tf

# generator that wraps an infinite, batched tf.data pipeline
def generator(dataset, batch_size):
    # batch before repeating so each yield is one training batch
    ds = dataset.batch(batch_size).repeat().prefetch(tf.data.AUTOTUNE)
    iterator = iter(ds)
    while True:
        # pull a fresh batch on every iteration
        x, y = iterator.get_next()
        yield x, y

# reading training and validation datasets
def read_tfrecord(example):
    tfrecord_format = (
        {
            "x": tf.io.FixedLenFeature([], tf.string),
            "y": tf.io.FixedLenFeature([], tf.string),
        }
    )
    example = tf.io.parse_single_example(example, tfrecord_format)
    x = tf.io.parse_tensor(example['x'], out_type=tf.float32)
    y = tf.io.parse_tensor(example['y'], out_type=tf.double)
    return x, y

batch_size = 32   # example value, pick to fit your GPU memory

filename = "train.tfrecord"
training_dataset = tf.data.TFRecordDataset(filename).map(read_tfrecord)
train_ds = generator(training_dataset, batch_size)

filename = "validation.tfrecord"
validation_dataset = tf.data.TFRecordDataset(filename).map(read_tfrecord)
valid_ds = generator(validation_dataset, batch_size)

kwargs = {'epochs': 35, 'verbose': 1}   # the remaining fit() arguments
kwargs['validation_data'] = valid_ds

# get your step counts with something like this
# (x and x_val stand for the full training and validation arrays)
training_steps = x.shape[0] // batch_size
validation_steps = x_val.shape[0] // batch_size

model.fit(train_ds,
          steps_per_epoch=training_steps,
          validation_steps=validation_steps,
          **kwargs)
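For what it's worth, model.fit can also consume a tf.data.Dataset directly, so the generator wrapper isn't strictly required; the key point is just to batch the data and pass explicit step counts once repeat() makes the pipelines infinite. Here is a minimal sketch of that variant, reusing the read_tfrecord parser above; n_train and n_val are hypothetical sample counts you'd substitute with your own.

import tensorflow as tf

batch_size = 32                  # example value, pick to fit your GPU memory
n_train, n_val = 10000, 2000     # hypothetical sample counts for illustration

train_ds = (tf.data.TFRecordDataset("train.tfrecord")
            .map(read_tfrecord)          # same parser as above
            .batch(batch_size)
            .repeat()
            .prefetch(tf.data.AUTOTUNE))
val_ds = (tf.data.TFRecordDataset("validation.tfrecord")
          .map(read_tfrecord)
          .batch(batch_size)
          .repeat()
          .prefetch(tf.data.AUTOTUNE))

# With infinite (repeated) datasets, fit() needs explicit step counts;
# omitting them is exactly why the original first epoch never ended.
model.fit(train_ds,
          validation_data=val_ds,
          steps_per_epoch=n_train // batch_size,
          validation_steps=n_val // batch_size,
          epochs=35,
          verbose=1)

Either way, the data is streamed from disk batch by batch, so neither the training set nor the validation set ever has to fit in memory.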