我在尝试理解TensorFlow的宽深学习教程。人口普查收入数据集有两个用于验证的文件:adult.data和adult.test。在一定数量的轮次后,它会打印一个评估(你可以在这里看到完整的代码:https://github.com/tensorflow/models/blob/master/official/wide_deep/wide_deep.py)。它使用“input_fn”来从csv文件中读取输入信息。它被用来读取两个文件,adult.data和adult.test。
def input_fn(data_file, num_epochs, shuffle, batch_size): """为Estimator生成一个输入函数。""" assert tf.gfile.Exists(data_file), ( '%s not found. Please make sure you have either run data_download.py or ' 'set both arguments --train_data and --test_data.' % data_file) def parse_csv(value): print('Parsing', data_file) columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS) features = dict(zip(_CSV_COLUMNS, columns)) labels = features.pop('income_bracket') return features, tf.equal(labels, '>50K') # 使用Dataset API从输入文件中提取行。 dataset = tf.data.TextLineDataset(data_file) if shuffle: dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train']) dataset = dataset.map(parse_csv, num_parallel_calls=5) # 我们在洗牌后调用repeat,而不是之前,这样可以防止不同的轮次混合在一起。 dataset = dataset.repeat(num_epochs) dataset = dataset.batch(batch_size) return dataset
它构建了一个估计器,DNNLinearCombinedClassifier,并像这样评估和打印准确率:
...results = model.evaluate(input_fn=lambda: input_fn( FLAGS.test_data, 1, False, FLAGS.batch_size))# 显示评估指标print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)print('-' * 60)for key in sorted(results): print('%s: %s' % (key, results[key]))
我理解你应该用批次来训练你的网络。我的问题是,为什么他们会用一个批次来评估模型?他们不应该使用整个评估数据集吗?该数据集有16281个验证值,它不应该像这样调用model.evaluate吗?:
_NUM_EXAMPLES = { 'train': 32561, 'validation': 16281,}...results = model.evaluate(input_fn=lambda: input_fn( FLAGS.test_data, 1, False, _NUM_EXAMPLES['validation']))
使用整个验证数据集是错误的吗?
回答:
训练和测试都需要使用数据的迷你批次,因为如果不这样做,可能会导致内存溢出错误(OOM)。你说的没错,问题在训练中更为关键,因为反向传播实际上会使内存消耗翻倍。但这并不意味着在推理中不可能出现OOM。
来自我经验的例子:
… 我相信还有很多我没见过的例子。根据你的资源,16281
可能足够小,可以装进一个批次,但总的来说,在推理中以批次迭代并为这种批次大小设置一个单独的设置是完全有意义的,例如,如果模型将在另一台资源更少的机器上运行。