如何在Tensorflow中恢复具有不同名称但相同形状的权重?

我在Tensorflow中有多个架构。其中一些共享某些部分的设计。

我想训练其中一个网络,并在另一个网络中使用相似层面的训练权重。

目前,我能够保存我想保存的权重,并在变量命名完全相同的架构中重新加载它们。

然而,当两个网络中的权重名称不同时,无法恢复。我在第一个网络中使用了这样的命名约定:

  • selector_network/c2w/var1

在第二个网络中,我使用的是这样的命名:

  • joint_network/c2w/var1

除此之外,这些变量在形状上是相似的。是否有可能在重新加载时更改名称,或者告诉Tensorflow这些变量应该放在哪里?

编辑: 我找到了@某人的一个脚本,它允许重命名Tensorflow检查点的变量:tensorflow_rename_variables

它不起作用。我得到了以下错误:

ValueError: Couldn't find 'checkpoint' file or checkpoints in given directory ./joint_pos_tagger_lemmatizer/fi/

回答:

tf.train.Saver 使用var_list参数的字典内置支持这种情况。该字典将检查点文件中对象的名称映射到您想要恢复的变量。

如果您想用“selector网络”的检查点来恢复“joint网络”,您可以这样做:

# var1 是您想要恢复的变量saver = tf.train.Saver(var_list={'selector_network/c2w/var1': var1})saver.restore(...)

如果您想恢复更多变量,只需扩展该字典即可。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注