我想在训练过程中将变量和偏置张量保存为检查点。我使用了tf.contrib.layers中的fully_connected()来实现几个全连接层。为此,我需要提取这些全连接层的变量和偏置张量。如何做到这一点?
回答:
需要注意的是:
- 没有必要仅仅为了保存它们而提取权重和偏置。对于tf.layers或tf.contrib.layers,如果trainable被设置为
True
,权重和偏置会被添加到GraphKeys.TRAINABLE_VARIABLES
中,这是GraphKeys.GLOBAL_VARIABLES
的一个子集。因此,如果你使用saver = tf.train.Saver(var_list=tf.global_variables())
和saver.save(sess, save_path, global_step)
在某个时刻,权重和偏置将会被保存。 - 在你确实需要提取变量的情况下,一种方法是使用
tf.get_variable
或tf.get_default_graph().get_tensor_by_name
,并使用正确的变量名称,如另一个答案中提到的。 - 你可能已经注意到TensorFlow中的类如
tf.layer.Dense
和tf.layers.Conv2D
。一旦构建,它们有weights
/variables
方法,可以返回权重和偏置张量。