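How do I change the dtype of parameters when restoring from a checkpoint?

I have a pre-trained TensorFlow checkpoint in which all parameters are float32.

How can I load the checkpoint parameters as float16? Or is there a way to modify the dtype of the checkpoint itself?

Below is my code snippet, which tries to load a float32 checkpoint into a float16 graph and fails with a dtype mismatch error.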

    import tensorflow as tf

    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
    assign = dict([(vari.name, vari) for vari in varis])
    saver = tf.train.Saver(assign)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(dense))
        save_path = saver.save(sess, "tmp.ckpt")

    tf.reset_default_graph()
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
    assign = dict([(vari.name, vari) for vari in varis])
    saver = tf.train.Saver(assign)

    with tf.Session() as sess:
        saver.restore(sess, "tmp.ckpt")
        print(sess.run(dense))
        pass

    # errors:
    # tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
    # tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
    # tensor_name = foo:0; expected dtype half does not equal original dtype float

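Answer:

After looking a bit into how savers work, it seems you can redefine how they are constructed through a builder object. For example, you could have a builder that loads the values as tf.float32 and then casts them to the actual dtype of each variable: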

    import tensorflow as tf
    from tensorflow.python.training.saver import BaseSaverBuilder

    class CastFromFloat32SaverBuilder(BaseSaverBuilder):
      # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
      def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                       restore_sequentially):
        from tensorflow.python.ops import io_ops
        restore_specs = []
        for saveable in saveables:
          for spec in saveable.specs:
            restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
        names, slices, dtypes = zip(*restore_specs)
        restore_dtypes = [tf.float32 for _ in dtypes]
        with tf.device("cpu:0"):
          restored = io_ops.restore_v2(filename_tensor, names, slices, restore_dtypes)
          return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]

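Note that this assumes every restored variable is stored as tf.float32 in the checkpoint. If necessary, you can adapt the builder to your use case, for example by passing the source type or types to the constructor. With that in place, you only need to use the builder above in the second saver to make your example work: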

    import tensorflow as tf

    with tf.Graph().as_default(), tf.Session() as sess:
        A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
        dense = tf.layers.dense(inputs=A, units=3)
        varis = tf.trainable_variables(scope=None)
        assign = {vari.name: vari for vari in varis}
        saver = tf.train.Saver(assign)
        sess.run(tf.global_variables_initializer())
        print('Value to save:')
        print(sess.run(dense))
        save_path = saver.save(sess, "ckpt/tmp.ckpt")

    with tf.Graph().as_default(), tf.Session() as sess:
        A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
        dense = tf.layers.dense(inputs=A, units=3)
        varis = tf.trainable_variables(scope=None)
        assign = {vari.name: vari for vari in varis}
        saver = tf.train.Saver(assign, builder=CastFromFloat32SaverBuilder())
        saver.restore(sess, "ckpt/tmp.ckpt")
        print('Restored value:')
        print(sess.run(dense))

Output:

    Value to save:
    [[ 0.50589913  0.33701038 -0.11597633]
     [ 0.27372625  0.27724823  0.49825498]
     [ 1.0897961  -0.29577428 -0.9173869 ]]
    Restored value:
    [[ 0.506    0.337   -0.11597]
     [ 0.2737   0.2773   0.4983 ]
     [ 1.09    -0.296   -0.9175 ]]
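
The restored values differ slightly from the saved ones because float16 keeps only about three significant decimal digits. The rounding can be reproduced without TensorFlow; the following is a minimal sketch using only the Python standard library, where the helper name `to_float16` is an illustrative assumption and not part of the answer above.

```python
import struct


def to_float16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value.

    The struct format character 'e' packs a float as binary16, so a
    pack/unpack round trip returns the value float16 would actually store.
    """
    return struct.unpack("<e", struct.pack("<e", value))[0]


# A few of the float32 values from the "Value to save" output above.
saved = [0.50589913, 0.33701038, 1.0897961]

for x in saved:
    print(x, "->", to_float16(x))
```

For instance, 0.50589913 is stored as 0.505859375, which the restored output shows as 0.506.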

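
The answer mentions that the builder could take the source type in its constructor. One way this might look is sketched below; the class name `CastFromDtypeSaverBuilder` and the `source_dtype` parameter are hypothetical, and the sketch relies on the same private TensorFlow 1.x module as the original builder, so it may not work on other versions.

```python
import tensorflow as tf
from tensorflow.python.ops import io_ops
from tensorflow.python.training.saver import BaseSaverBuilder


class CastFromDtypeSaverBuilder(BaseSaverBuilder):
    """Hypothetical variant of CastFromFloat32SaverBuilder.

    Reads every tensor from the checkpoint as `source_dtype`, then casts it
    to the dtype of the variable it is restored into. Like the original, it
    assumes all checkpoint tensors share a single dtype.
    """

    def __init__(self, source_dtype=tf.float32, **kwargs):
        super().__init__(**kwargs)
        self._source_dtype = source_dtype

    def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                     restore_sequentially):
        # Collect (name, slice, target dtype) for every tensor to restore.
        restore_specs = []
        for saveable in saveables:
            for spec in saveable.specs:
                restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
        names, slices, dtypes = zip(*restore_specs)
        # Read with the checkpoint's dtype instead of the variables' dtype.
        restore_dtypes = [self._source_dtype for _ in dtypes]
        with tf.device("cpu:0"):
            restored = io_ops.restore_v2(
                filename_tensor, names, slices, restore_dtypes)
            # Cast each restored tensor to the dtype its variable expects.
            return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]
```

It would be passed to the second saver in the same way, for example `tf.train.Saver(assign, builder=CastFromDtypeSaverBuilder(source_dtype=tf.float32))`.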