我在Tensorflow中有多个架构。其中一些共享某些部分的设计。
我想训练其中一个网络,并在另一个网络中使用相似层面的训练权重。
目前,我能够保存我想保存的权重,并在变量命名完全相同的架构中重新加载它们。
然而,当两个网络中的权重名称不同时,无法恢复。我在第一个网络中使用了这样的命名约定:
- selector_network/c2w/var1
在第二个网络中,我使用的是这样的命名:
- joint_network/c2w/var1
除此之外,这些变量在形状上是相似的。是否有可能在重新加载时更改名称,或者告诉Tensorflow这些变量应该放在哪里?
编辑: 我找到了@某人的一个脚本,它允许重命名Tensorflow检查点的变量:tensorflow_rename_variables。
它不起作用。我得到了以下错误:
ValueError: Couldn't find 'checkpoint' file or checkpoints in given directory ./joint_pos_tagger_lemmatizer/fi/
回答:
tf.train.Saver 使用var_list
参数的字典内置支持这种情况。该字典将检查点文件中对象的名称映射到您想要恢复的变量。
如果您想用“selector网络”的检查点来恢复“joint网络”,您可以这样做:
# var1 是您想要恢复的变量saver = tf.train.Saver(var_list={'selector_network/c2w/var1': var1})saver.restore(...)
如果您想恢复更多变量,只需扩展该字典即可。