我在Tensorflow中完成了一次简单的线性回归。我如何知道回归的系数是什么?我已经阅读了文档,但我在任何地方都找不到相关信息! (https://www.tensorflow.org/api_docs/python/tf/estimator/LinearRegressor)
编辑 代码示例
回答:
编辑:正如@某人指出的那样,在这个回答发布后,有了一些变化。现在估计器方法有了get_variable_names
和get_variable_value
,估计器的权重似乎不再自动添加到tf.GraphKeys.MODEL_VARIABLES
中了。
估计器被设计为基本上像一个黑盒子,因此没有直接的API来检索权重。即使像你这样,你是定义模型的人(相对于使用预先存在的模型),你也无法直接从估计器对象中访问参数。
尽管如此,你仍然可以通过其他方式检索变量。如果你知道变量的名称,一个选项是简单地从图形对象中获取它们,使用get_operation_by_name
或get_tensor_by_name
。一个更实用且更通用的选项是使用集合。无论是当你调用tf.get_variable
时,还是在之后调用tf.add_to_collection
时,你都可以将模型变量放到一个共同的集合名称下,以便以后检索。如果你查看tf.estimator.LinearRegressor
是如何实际构建的(在这个模块中搜索linear_model
函数),所有模型变量都被添加到了tf.GraphKeys.GLOBAL_VARIABLES
和tf.GraphKeys.MODEL_VARIABLES
中。这(据我推测,我没有真正检查过)对所有可用的预制估计器都是通用的,所以通常在使用这些估计器中的一个时,你应该可以简单地这样做:
model_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
在这种情况下,最好使用tf.GraphKeys.MODEL_VARIABLES
而不是tf.GraphKeys.GLOBAL_VARIABLES
,后者具有更通用的目的,并且可能包含其他不相关的变量。