有人能解释一下TensorFlow的基础教程吗?

我正在尝试完成TensorFlow第一个教程的第二部分:https://www.tensorflow.org/get_started/get_started

“基本用法”:

import tensorflow as tf# NumPy常用于加载、操作和预处理数据。import numpy as np# 声明特征列表。我们只有一个实值特征。还有许多其他类型更复杂且有用的列。features = [tf.contrib.layers.real_valued_column("x", dimension=1)]# 估计器是调用训练(拟合)和评估(推理)的前端。有许多预定义类型,如线性回归、逻辑回归、线性分类、逻辑分类,以及许多神经网络分类器和回归器。以下代码提供了一个进行线性回归的估计器。estimator = tf.contrib.learn.LinearRegressor(feature_columns=features)# TensorFlow提供了许多辅助方法来读取和设置数据集。这里我们使用两个数据集:一个用于训练,一个用于评估。我们必须告诉函数我们想要多少批数据(num_epochs)以及每批数据的大小。x_train = np.array([1., 2., 3., 4.])y_train = np.array([0., -1., -2., -3.])x_eval = np.array([2., 5., 8., 1.])y_eval = np.array([-1.01, -4.1, -7, 0.])input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x_train}, y_train,                                              batch_size=4,                                              num_epochs=1000)eval_input_fn = tf.contrib.learn.io.numpy_input_fn(    {"x":x_eval}, y_eval, batch_size=4, num_epochs=1000)# 我们可以通过调用方法并传入训练数据集来调用1000次训练步骤。estimator.fit(input_fn=input_fn, steps=1000)# 在这里我们评估模型的表现。train_loss = estimator.evaluate(input_fn=input_fn)eval_loss = estimator.evaluate(input_fn=eval_input_fn)print("train loss: %r"% train_loss)print("eval loss: %r"% eval_loss)

有人能解释一下在这个代码中计算图隐藏在哪里吗?我没有看到任何对tf.Graph()tf.Session()的调用。

features变量是用来做什么的?数据似乎从未进入其中,因为数据提供者是’input_fn’。我如何查看会话和图的实际计算图?

为什么我设置epoch数量的地方有两个?(estimator.fitnumpy_input_fn

如果我有两个不同的估计器,分别是estimator1.fit(..., steps=20)estimator2.fit(..., steps=50),我需要设置num_epochs=70吗?还是num_epochs=max(20,50)?如果input_fn是从fit中调用的,而不是反过来,它如何控制线程数量?


回答:

  1. tf.Graph()tf.Session()在哪里?

TensorFlow会为你创建一个默认的计算图。所以你在上面定义的所有变量和操作都将在默认图上。会话是在估计器函数内部定义的。例如,estimator.fit()会创建一个监控训练会话。

  1. features变量是用来做什么的?

它用于初始化LinearRegressor()模型。线性回归参数是基于特征变量设置的。

  1. input_fnestimator中设置的epoch?

input_fn是一个队列,所以input_fn中的epoch告诉每个输入数据需要被弹出到队列中的次数。fit中的epoch测量每个输入在训练中需要使用的次数。所以如果你的fit中的epoch大于队列中的epoch,训练会在队列停止时停止。

  1. 对于两个不同的估计器的队列大小。

input_fn中的epoch应该大于或等于估计器中epoch的总和。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注