My TensorFlow model is defined as follows:
X = tf.placeholder(tf.float32, [None, training_set.shape[1]], name='X')
Y = tf.placeholder(tf.float32, [None, training_labels.shape[1]], name='Y')
A1 = tf.contrib.layers.fully_connected(X, num_outputs=50, activation_fn=tf.nn.relu)
A1 = tf.nn.dropout(A1, 0.8)
A2 = tf.contrib.layers.fully_connected(A1, num_outputs=2, activation_fn=None)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=A2, labels=Y))
global_step = tf.Variable(0, trainable=False)
start_learning_rate = 0.001
learning_rate = tf.train.exponential_decay(start_learning_rate, global_step, 200, 0.1, True)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
Now I want to save this model while omitting the tensor Y (Y is the label tensor used for training; X is the actual input). Also, when specifying the output node for freeze_graph.py, should I refer to it as "A2", or is it saved under some other name?
Answer:
Although you did not define any variables by hand, the snippet above actually contains 15 saveable variables. You can inspect them with an internal TensorFlow function:
from tensorflow.python.ops.variables import _all_saveable_objects

for obj in _all_saveable_objects():
    print(obj)
For the code above, it produces the following list:
<tf.Variable 'fully_connected/weights:0' shape=(100, 50) dtype=float32_ref>
<tf.Variable 'fully_connected/biases:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'fully_connected_1/weights:0' shape=(50, 2) dtype=float32_ref>
<tf.Variable 'fully_connected_1/biases:0' shape=(2,) dtype=float32_ref>
<tf.Variable 'Variable:0' shape=() dtype=int32_ref>
<tf.Variable 'beta1_power:0' shape=() dtype=float32_ref>
<tf.Variable 'beta2_power:0' shape=() dtype=float32_ref>
<tf.Variable 'fully_connected/weights/Adam:0' shape=(100, 50) dtype=float32_ref>
<tf.Variable 'fully_connected/weights/Adam_1:0' shape=(100, 50) dtype=float32_ref>
<tf.Variable 'fully_connected/biases/Adam:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'fully_connected/biases/Adam_1:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'fully_connected_1/weights/Adam:0' shape=(50, 2) dtype=float32_ref>
<tf.Variable 'fully_connected_1/weights/Adam_1:0' shape=(50, 2) dtype=float32_ref>
<tf.Variable 'fully_connected_1/biases/Adam:0' shape=(2,) dtype=float32_ref>
<tf.Variable 'fully_connected_1/biases/Adam_1:0' shape=(2,) dtype=float32_ref>
This list contains the variables from both fully_connected layers as well as several variables created by the Adam optimizer (see this question). Note that the X and Y placeholders are not in this list, so there is no need to exclude them. Of course, these tensors do exist in the meta graph, but they hold no values, hence they are not saveable.
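You can verify this yourself by inspecting the default graph. A small sketch, assuming the snippet above has already been executed so the graph is built:

graph = tf.get_default_graph()
print(graph.get_tensor_by_name('X:0'))  # the placeholder exists as a tensor in the (meta) graph
print(graph.get_tensor_by_name('Y:0'))
# ... but neither placeholder appears among the saveable objects
print([v for v in _all_saveable_objects() if 'X' in v.name or 'Y' in v.name])  # []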
The _all_saveable_objects() list is what the TensorFlow saver saves by default when no variables are provided explicitly. Hence, the answer to your main question is simple:
saver = tf.train.Saver()  # all saveable objects!
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, "...")
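Restoring later is symmetric. A minimal sketch; the checkpoint path "./model.ckpt" and the dummy input are only illustrative placeholders:

import numpy as np

saver = tf.train.Saver()
with tf.Session() as sess:
    # restores all 15 saveable variables; X and Y need no stored values
    saver.restore(sess, "./model.ckpt")
    # at inference time only the input placeholder is fed
    logits = sess.run(A2, feed_dict={X: np.zeros((1, 100), dtype=np.float32)})  # 100 = training_set.shape[1]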
It is not possible to provide a name for the tf.contrib.layers.fully_connected function (which is why it is saved as fully_connected_1/...), but I encourage you to switch to tf.layers.dense, which has a name argument. To see why this is a good idea, take a look at this and this discussion.
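For example, here is a sketch of the same forward pass rewritten with tf.layers.dense; the names 'dense1' and 'logits', and the tf.identity wrapper, are illustrative additions rather than part of the original code:

A1 = tf.layers.dense(X, units=50, activation=tf.nn.relu, name='dense1')
A1 = tf.nn.dropout(A1, 0.8)
A2 = tf.layers.dense(A1, units=2, activation=None, name='logits')
# giving the output a stable, explicit name also makes it easy to pass to
# freeze_graph.py via --output_node_names=output instead of guessing the
# auto-generated op name
output = tf.identity(A2, name='output')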