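How can I change the dtype of parameters when restoring from a checkpoint?

I have a pre-trained TensorFlow checkpoint in which all parameters are float32.

How can I load the checkpoint parameters as float16? Or is there a way to modify the dtype of a checkpoint?

Below is my code snippet, which tries to load a float32 checkpoint into a float16 graph and fails with a dtype mismatch error.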

import tensorflow as tf

A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dense))
    save_path = saver.save(sess, "tmp.ckpt")

tf.reset_default_graph()
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    saver.restore(sess, "tmp.ckpt")
    print(sess.run(dense))
    pass

# errors:
# tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
# tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
# tensor_name = foo:0; expected dtype half does not equal original dtype float

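Answer:

Digging a little into how savers work, it seems you can redefine how they are constructed through a builder object. For example, you could have a builder that loads the values as tf.float32 and then casts them to the actual dtype of each variable: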

import tensorflow as tf
from tensorflow.python.training.saver import BaseSaverBuilder

class CastFromFloat32SaverBuilder(BaseSaverBuilder):
  # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    from tensorflow.python.ops import io_ops
    restore_specs = []
    for saveable in saveables:
      for spec in saveable.specs:
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    names, slices, dtypes = zip(*restore_specs)
    # Read every tensor as float32, whatever dtype the graph variable has
    restore_dtypes = [tf.float32 for _ in dtypes]
    with tf.device("cpu:0"):
      restored = io_ops.restore_v2(filename_tensor, names, slices, restore_dtypes)
      # Cast each restored tensor to the dtype its variable expects
      return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]

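Note that this assumes all restored variables are tf.float32. If necessary, you can adapt the builder to your use case, for example by passing the source dtype or dtypes in the constructor. With this, you only need to use the builder above in the second saver to make your example work: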

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign)
    sess.run(tf.global_variables_initializer())
    print('Value to save:')
    print(sess.run(dense))
    save_path = saver.save(sess, "ckpt/tmp.ckpt")

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign, builder=CastFromFloat32SaverBuilder())
    saver.restore(sess, "ckpt/tmp.ckpt")
    print('Restored value:')
    print(sess.run(dense))

Output:

Value to save:
[[ 0.50589913  0.33701038 -0.11597633]
 [ 0.27372625  0.27724823  0.49825498]
 [ 1.0897961  -0.29577428 -0.9173869 ]]
Restored value:
[[ 0.506    0.337   -0.11597]
 [ 0.2737   0.2773   0.4983 ]
 [ 1.09    -0.296   -0.9175 ]]
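
The restored result agrees with the saved one only to about three significant digits because float16 has a 10-bit mantissa. The following is a minimal sketch of that rounding using only Python's standard `struct` module, whose `e` format is IEEE 754 half precision. The helper name `to_half` is hypothetical and is not part of TensorFlow; the sample numbers are taken from the first and last rows of the saved output above.

```python
import struct


def to_half(x):
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Hypothetical helper for illustration only: packs the number as a
    16-bit float and unpacks it again, which discards the extra precision.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Sample float32 values copied from the "Value to save" output above.
saved = [0.50589913, 1.0897961]

for value in saved:
    # Print the original value next to its half-precision rounding.
    print(value, "->", to_half(value))
```

Here 0.50589913 rounds to 0.505859375, which is displayed as 0.506 in float16. Keep in mind that in the restored graph the weights themselves are cast to float16 and the dense layer is then evaluated in float16, so the restored output is close to, but not necessarily the exact rounding of, the float32 output.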
