在TensorFlow中读取数据

我正在通过文档学习如何使用TensorFlow。但是,我无法理解以下两个函数。我尝试在文档中搜索过,但没有得到明确的答案。

input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x_train}, y_train,                                          batch_size=4,                                          num_epochs=1000)eval_input_fn = tf.contrib.learn.io.numpy_input_fn(    {"x":x_eval}, y_eval, batch_size=4, num_epochs=1000)

另外,如果您能解释一下函数中的参数是什么,那就太好了。提前感谢!


回答:

我同意文档对这些值的解释不够清楚。

第一个参数是网络输入值的字典。

第二个参数是输出值。

batch_size 是每次用于训练网络的项目数量,当数据量很大时,如果每次用一个值来训练网络,训练速度会太慢,因此使用批次处理。如果 x_train 有16个项目,批次大小为4,那么网络将首先用前4个项目进行训练,然后是接下来的4个,直到使用完所有16个输入。

num_epochs 是网络节点优化尝试之前的前向和后向传递次数。

对于批次大小和轮数,您可以根据希望花费在训练上的时间与期望的性能来调整这些参数。

有关术语的详细描述,请参阅这里

一个示例中,您可以使用 input_fn 来获取特征和目标值,以便传递这些数据进行运行。

例如:

with self.test_session() as session:  input_fn = numpy_io.numpy_input_fn({"x":x_train}, y_train,                                      batch_size=4,                                      num_epochs=1000)  features, target = input_fn()  res = session.run([features, target])

这两个函数做的事情相同,只是 input_fn 用于训练网络,而 eval_input_fn 用于评估,因此唯一的区别在于 x 和 y 值,因为在训练数据上评估性能时,性能将是上限或最佳情况。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注