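I have a pre-trained TensorFlow checkpoint in which all parameters are of type float32.
How can I load the checkpoint parameters as float16? Or is there a way to change the data type of the checkpoint itself?
Below is my code snippet. It tries to load a float32 checkpoint into a float16 graph and fails with a type mismatch error.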
```python
import tensorflow as tf

A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dense))
    save_path = saver.save(sess, "tmp.ckpt")

tf.reset_default_graph()

A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    saver.restore(sess, "tmp.ckpt")
    print(sess.run(dense))

# errors:
# tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
# tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
# tensor_name = foo:0; expected dtype half does not equal original dtype float
```
Answer:
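Having looked a bit into how savers work, it seems you can redefine how they build their ops through a builder
object. You could, for example, have a builder that loads the values as tf.float32
and then casts them to the actual type of each variable: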
```python
import tensorflow as tf
from tensorflow.python.ops import io_ops
from tensorflow.python.training.saver import BaseSaverBuilder


class CastFromFloat32SaverBuilder(BaseSaverBuilder):
    # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
    def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                     restore_sequentially):
        # Collect (name, slice, dtype) for every tensor to restore;
        # spec.dtype is the dtype of the variable in the *current* graph.
        restore_specs = []
        for saveable in saveables:
            for spec in saveable.specs:
                restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
        names, slices, dtypes = zip(*restore_specs)
        # Read every tensor as float32, the dtype it was saved with.
        restore_dtypes = [tf.float32 for _ in dtypes]
        with tf.device("cpu:0"):
            restored = io_ops.restore_v2(filename_tensor, names, slices,
                                         restore_dtypes)
        # Cast each tensor to the dtype its variable expects (e.g. float16).
        return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]
```
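Stripped of the TensorFlow plumbing, `bulk_restore` above does two things: it reads each tensor at the dtype it was *stored* with (reading it at any other dtype is exactly what TensorFlow rejects with "expected dtype half does not equal original dtype float"), and then it casts the result to the dtype the new graph expects. The sketch below models only that two-step logic with Python's standard `struct` module (`'f'` is float32, `'e'` is float16). It does not touch real checkpoint files; `save`, `load`, `cast` and `restore_with_cast` are hypothetical helpers for illustration, and the source dtype is supplied per tensor instead of being hard-coded to float32.

```python
import struct

# struct format codes: 'f' is IEEE 754 float32, 'e' is IEEE 754 float16.
FORMATS = {"float32": "f", "float16": "e"}


def save(values, dtype):
    """Serialize floats at `dtype`; stands in for one checkpoint entry."""
    code = FORMATS[dtype]
    return struct.pack(f"<{len(values)}{code}", *values)


def load(blob, dtype):
    """Deserialize bytes that were written by save() at `dtype`."""
    code = FORMATS[dtype]
    count = len(blob) // struct.calcsize("<" + code)
    return list(struct.unpack(f"<{count}{code}", blob))


def cast(values, dtype):
    """Round each value to the nearest number representable in `dtype`."""
    return load(save(values, dtype), dtype)


def restore_with_cast(checkpoint, source_dtypes, target_dtypes):
    """Read each entry at the dtype it was saved with, then cast it to the
    dtype the new graph expects (the two steps the builder performs)."""
    restored = {}
    for name, blob in checkpoint.items():
        values = load(blob, source_dtypes[name])            # read as stored
        restored[name] = cast(values, target_dtypes[name])  # cast to target
    return restored


if __name__ == "__main__":
    ckpt = {"dense/kernel": save([0.50589913, -0.11597633], "float32")}
    print(restore_with_cast(ckpt,
                            source_dtypes={"dense/kernel": "float32"},
                            target_dtypes={"dense/kernel": "float16"}))
    # prints {'dense/kernel': [0.505859375, -0.115966796875]}
```

Keeping the source dtype as a per-tensor mapping is one way to relax the "everything was saved as float32" assumption; in the real builder the same idea would mean replacing the hard-coded `tf.float32` list with dtypes looked up by `spec.name`.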
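Note that the builder above assumes that every restored variable was saved as tf.float32.
You can adapt the builder to your use case as needed, for example by passing the source type (or one type per variable) to its constructor. With that in place, you only need to pass the builder to the second saver to make your example work: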
```python
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign)
    sess.run(tf.global_variables_initializer())
    print('Value to save:')
    print(sess.run(dense))
    save_path = saver.save(sess, "ckpt/tmp.ckpt")

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign, builder=CastFromFloat32SaverBuilder())
    saver.restore(sess, "ckpt/tmp.ckpt")
    print('Restored value:')
    print(sess.run(dense))
```
Output:

```
Value to save:
[[ 0.50589913  0.33701038 -0.11597633]
 [ 0.27372625  0.27724823  0.49825498]
 [ 1.0897961  -0.29577428 -0.9173869 ]]
Restored value:
[[ 0.506    0.337   -0.11597]
 [ 0.2737   0.2773   0.4983 ]
 [ 1.09    -0.296   -0.9175 ]]
```
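The restored outputs match the saved ones to only about three or four significant digits. That is roughly the resolution of float16, and the gap combines the rounding of the restored parameters with the float16 arithmetic inside the dense layer. As a quick sanity check that needs only the standard library (not TensorFlow), this sketch derives the machine epsilon of both formats and the largest finite float16 from their IEEE 754 bit patterns. `bits_to_float` is a hypothetical helper for illustration. Any checkpoint value above that maximum would become inf after the cast, so it may be worth checking the weights' range before converting.

```python
import struct


def bits_to_float(bits, bits_fmt, float_fmt):
    """Reinterpret an unsigned integer bit pattern as an IEEE 754 float."""
    return struct.unpack(float_fmt, struct.pack(bits_fmt, bits))[0]


# Gap between 1.0 and the next representable value (machine epsilon).
# 0x3C00 / 0x3F800000 are 1.0 in float16 / float32; +1 bumps the last mantissa bit.
EPS16 = bits_to_float(0x3C01, "<H", "<e") - 1.0      # float16: 10 mantissa bits
EPS32 = bits_to_float(0x3F800001, "<I", "<f") - 1.0  # float32: 23 mantissa bits

# Largest finite float16 (exponent 11110, mantissa all ones).
MAX16 = bits_to_float(0x7BFF, "<H", "<e")

if __name__ == "__main__":
    print(f"float16 eps = {EPS16:.3e}, max = {MAX16}")  # float16 eps = 9.766e-04, max = 65504.0
    print(f"float32 eps = {EPS32:.3e}")                 # float32 eps = 1.192e-07
```

A relative step of about 1e-3 is why values such as 0.50589913 come back as 0.506.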