使用Tensorflow创建数据集时,为什么需要返回一个函数对象?

我是机器学习的新手,正在尝试使用Tensorflow文档中的教程从这里创建一个机器学习模型,但我无法理解代码的这一部分

 def make_input_fn(data_df, label_df, num_epochs=10, shuffle=True, batch_size=32):  def input_function():          # 内部函数,将被返回    ds = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))     # 使用数据及其标签创建tf.data.Dataset对象    if shuffle:      ds = ds.shuffle(1000)   # 随机化数据顺序    ds = ds.batch(batch_size).repeat(num_epochs)      # 将数据集分成32的批次,并重复处理指定的轮次    return ds  # 返回数据集的一个批次  return input_function  # 返回一个函数对象以供使用

然后将函数的输出存储在一个变量中

train_input_fn = make_input_fn(dftrain, y_train)

最后使用数据集训练模型

linear_est.train(train_input_fn)

我无法理解为什么在make_input_function中只返回内部函数的函数名而不是直接返回我们的数据集并将其传递给模型进行训练。

我是Python初学者,刚开始学习机器学习,我找不到我问题的正确答案,所以如果有人能以适合初学者的方式解释,我将非常感激。


回答:

我无法理解为什么在make_input_function中只返回内部函数的函数名而不是直接返回我们的数据集并将其传递给模型进行训练。

在Python编程中,这被称为柯里化,它用于将多参数函数转换为单参数函数,通过逐步嵌套函数参数进行评估。柯里化还将一个参数与另一个参数结合,形成执行时的相对模式。

在Tensorflow中,根据文档(https://www.tensorflow.org/api_docs/python/tf/estimator/LinearClassifier#train)。

train(    input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None)

估计器的train方法期望一个参数input_fn。原因是每次调用Estimator.train()时,它都会通过调用input_fnmodel_fn创建一个新的图,并将它们连接在一起。如果你提供的是张量或数据集,会导致不同的错误。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注