优化TensorFlow Estimator API中模型的特定变量

我需要冻结模型的某些部分,并只训练特定的变量。现在,使用低级API,我可以直接将var_list传递给tf.train.Optimizer.minimize方法。但是,当我使用TensorFlow Estimator时,我只能传递优化器本身,然后在Estimator的内部循环中使用它来最小化损失。

我能想到的唯一解决方案是定义一个自定义优化器,并覆盖Optimizer.minimize方法。类似这样:

def minimize(self, *args, **kwargs):    print("Inside...")    if not kwargs['var_list']:       kwargs['var_list'] = self.var_list    return super(MyOptimizer, self).minimize(*args, **kwargs)

现在,我期望在每个训练步骤中看到“Inside…”这句话被打印出来;特别是当我看到模型正常训练时。这似乎表明我的minimize函数被完全忽略了,我似乎无法弄清楚为什么会这样。

那么,覆盖minimize是否正确,或者使用Estimator有更好的方法来做到这一点?


回答:

您可以简单地通过指定model_fn函数来创建自定义的Estimator

    def model_fn(features, labels, mode):      logits = model_architecture(features)      loss = loss_function(logits, labels)      if mode == tf.estimator.ModeKeys.TRAIN:        optimizer = optimizer        train_op = optimizer.minimize(loss=loss,                                       global_step=global_step,                                      var_list=variables_to_minimize)      return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注