使用Tensorflow自动化迁移学习的方法?

我正在使用Tensorflow构建和训练几个神经网络。这些网络在相关任务(自然语言处理)上进行监督学习。

所有我的神经网络的共同点是它们共享一些早期的层(有些共享2层,其他更多)。

我希望能够共享从一个架构中训练好的公共层的权重,来初始化另一个架构。

目前我的做法是每次想要转移权重时,我都会编写一个单独的(临时)的代码片段。这使得我的项目变得杂乱且耗时。

有没有人知道一种可以自动化权重转移过程的方法?例如,自动检测公共层,然后初始化相应的权重。


回答:

你可以为感兴趣的变量集创建一个tf.Saver,这样你就可以在另一个图中恢复这些变量,只要它们的名称相同。你可以使用集合来存储这些变量,然后为集合创建saver:

TRANSFERABLE_VARIABLES = "transferable_variable"# ...my_var = tf.get_variable(...)tf.add_to_collection(TRANSFERABLE_VARIABLES, my_var)# ...saver = tf.Saver(tf.get_collection(TRANSFERABLE_VARIABLES), ...)

这样你就可以在一个图中调用save,在另一个图中调用restore来转移权重。

如果你不想将任何东西写入磁盘,那么我想除了手动复制/粘贴值之外,没有其他方法。然而,通过使用集合和完全相同的构建过程,这也可以在一定程度上实现自动化:

model1_graph = create_model1()model2_graph = create_model2()with model1_graph.as_default(), tf.Session() as sess:    # 训练...    # 获取学习到的权重    transferable_weights = sess.run(tf.get_collection(TRANSFERABLE_VARIABLES))with model2_graph.as_default(), tf.Session() as sess:    # 从另一个模型加载权重    for var, weight in zip(tf.get_collection(TRANSFERABLE_VARIABLES),                           transferable_weights):        var.load(weight, sess)    # 继续训练...

同样,这只有在公共层的构建方式相同的情况下才有效,因为集合中变量的顺序对于两个图来说应该是相同的。

更新:

如果你想确保恢复的变量不用于训练,你有几种选择,尽管它们可能都需要对你的代码进行更多的更改。一个trainable变量只是一个包含在集合tf.GrapKeys.TRAINABLE_VARIABLES中的变量,所以你可以在第二个图中创建转移变量时指定trainable=False,这样恢复过程应该会一样。如果你想更动态地自动完成这一点,这在某种程度上是可能的,但请记住:用于训练的变量列表必须在创建优化器之前已知,并且在创建新的优化器之前无法更改。了解这一点后,我认为没有任何解决方案不需要从第一个图中传递一个包含可转移变量名称的列表。例如:

with model1_graph.as_default():    transferable_names = [v.name for v in tf.get_collection(TRANSFERABLE_VARIABLES)]

然后,在第二个图的构建过程中,在定义模型之后,创建优化器之前,你可以这样做:

train_vars = [v for v in tf.get_collection(tf.GrapKeys.TRAINABLE_VARIABLES)              if v.name not in transferable_names]# 假设model2_graph是当前的默认图tf.get_default_graph().clear_collection(tf.GrapKeys.TRAINABLE_VARIABLES)for v in train_vars:    tf.add_to_collection(tf.GrapKeys.TRAINABLE_VARIABLES, v)# 创建优化器...

另一种选择是不修改集合tf.GrapKeys.TRAINABLE_VARIABLES,而是将你希望优化的变量列表(示例中的train_vars)作为参数var_list传递给优化器的minimize方法。原则上,我个人不太喜欢这种方法,因为我认为集合的内容应该与它们的语义目的相匹配(毕竟,代码的其他部分可能出于其他目的使用相同的集合),但这取决于具体情况,我想。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注