宽深TensorFlow教程:为什么只使用评估数据集的一个批次进行测试?

我在尝试理解TensorFlow的宽深学习教程。人口普查收入数据集有两个用于验证的文件:adult.data和adult.test。在一定数量的轮次后,它会打印一个评估(你可以在这里看到完整的代码:https://github.com/tensorflow/models/blob/master/official/wide_deep/wide_deep.py)。它使用“input_fn”来从csv文件中读取输入信息。它被用来读取两个文件,adult.data和adult.test。

def input_fn(data_file, num_epochs, shuffle, batch_size):  """为Estimator生成一个输入函数。"""  assert tf.gfile.Exists(data_file), (      '%s not found. Please make sure you have either run data_download.py or '      'set both arguments --train_data and --test_data.' % data_file)  def parse_csv(value):    print('Parsing', data_file)    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)    features = dict(zip(_CSV_COLUMNS, columns))    labels = features.pop('income_bracket')    return features, tf.equal(labels, '>50K')  # 使用Dataset API从输入文件中提取行。  dataset = tf.data.TextLineDataset(data_file)  if shuffle:    dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])  dataset = dataset.map(parse_csv, num_parallel_calls=5)  # 我们在洗牌后调用repeat,而不是之前,这样可以防止不同的轮次混合在一起。  dataset = dataset.repeat(num_epochs)  dataset = dataset.batch(batch_size)  return dataset

它构建了一个估计器,DNNLinearCombinedClassifier,并像这样评估和打印准确率:

...results = model.evaluate(input_fn=lambda: input_fn(    FLAGS.test_data, 1, False, FLAGS.batch_size))# 显示评估指标print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)print('-' * 60)for key in sorted(results):  print('%s: %s' % (key, results[key]))

我理解你应该用批次来训练你的网络。我的问题是,为什么他们会用一个批次来评估模型?他们不应该使用整个评估数据集吗?该数据集有16281个验证值,它不应该像这样调用model.evaluate吗?:

_NUM_EXAMPLES = {  'train': 32561,  'validation': 16281,}...results = model.evaluate(input_fn=lambda: input_fn(    FLAGS.test_data, 1, False, _NUM_EXAMPLES['validation']))

使用整个验证数据集是错误的吗?


回答:

训练和测试都需要使用数据的迷你批次,因为如果不这样做,可能会导致内存溢出错误(OOM)。你说的没错,问题在训练中更为关键,因为反向传播实际上会使内存消耗翻倍。但这并不意味着在推理中不可能出现OOM。

来自我经验的例子:

… 我相信还有很多我没见过的例子。根据你的资源,16281可能足够小,可以装进一个批次,但总的来说,在推理中以批次迭代并为这种批次大小设置一个单独的设置是完全有意义的,例如,如果模型将在另一台资源更少的机器上运行。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注