文档对此并不是100%清晰:
注意:在训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作会被放置在tf.GraphKeys.UPDATE_OPS中,因此它们需要作为train_op的依赖项。例如:
(参见 https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization)
这是否意味着只需要以下操作就能保存moving_mean
和moving_variance
?
def model_fn(features, labels, mode, params): training = mode == tf.estimator.ModeKeys.TRAIN extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) x = tf.reshape(features, [-1, 64, 64, 3]) x = tf.layers.batch_normalization(x, training=training) # ... with tf.control_dependencies(extra_update_ops): train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
换句话说,仅仅使用
with tf.control_dependencies(extra_update_ops):
就能处理保存moving_mean
和moving_variance
吗?
回答:
事实证明,这些值确实可以自动保存。特殊情况是,如果你在将批量归一化操作添加到图形之前获取更新操作集合,那么更新集合将为空。这之前没有被文档化,但现在已经记录了。
使用批量归一化时的注意事项是在调用tf.layers.batch_normalization
之后再调用tf.get_collection(tf.GraphKeys.UPDATE_OPS)
。