我正在使用 tf.estimator
API 来训练模型。
据我所知,model_fn
定义了计算图,根据 mode
返回不同的 tf.estimator.EstimatorSpec
。
在 mode==tf.estimator.ModeKeys.TRAIN
时,可以指定一个 train_op
,在每次训练迭代时调用,该操作会改变 tf.Variable
的 trainable
实例,以优化某个损失函数。
我们称这个 train_op 为 optimizer
,变量为 A
和 B
。
为了加速预测和评估,我希望有一个辅助的非可训练 tf.Variable
张量 C
,它完全依赖于已经训练好的变量。这个张量的值因此可以导出。这个张量不影响训练损失。假设我们想要:
C = tf.Variable(tf.matmul(A,B))update_op = tf.assign(C, tf.matmul(A,B))
- 我尝试过的方法:
将 tf.group(optimizer, update_op)
作为 train_op
传递给 EstimatorSpec
效果很好但大大减慢了训练速度,因为 train_op
现在在每次迭代时都会更新 C
。
因为 C
只在评估/预测时需要,在训练结束时调用一次 update_op
就足够了。
是否可以在 tf.estimator.Estimator
训练结束时分配变量?
回答:
一般来说,模型函数的单次迭代并不知道训练是否会在它运行后结束,所以我怀疑这不能直接实现。我看到有两个选项:
-
如果你只在训练后需要辅助变量,你可以使用
tf.estimator.Estimator.get_variable_value
(参见 这里)在训练后提取变量A
和B
的值作为 numpy 数组,然后进行计算得到C
。然而,这样C
就不会成为模型的一部分。 -
使用钩子(参见 这里)。你可以编写一个带有
end
方法的钩子,该方法将在会话结束时(即训练停止时)被调用。你可能需要研究钩子是如何定义/使用的——例如,在 这里 你可以找到 Tensorflow 中大多数“基本”钩子的实现。一个粗略的框架可能看起来像这样:class UpdateHook(SessionRunHook): def __init__(update_variable, other_variables): self.update_op = tf.assign(update_variable, some_fn(other_variables)) def end(session): session.run(self.update_op)
由于钩子需要访问变量,你需要在模型函数内部定义钩子。你可以在
EstimatorSpec
中将这样的钩子传递给训练过程(参见 这里)。我没有测试过这个方法!我不确定你是否可以在钩子内部定义操作。如果不能,希望在模型函数内部定义更新操作并直接传递给钩子会有效。