我需要冻结模型的某些部分,并只训练特定的变量。现在,使用低级API,我可以直接将var_list
传递给tf.train.Optimizer.minimize
方法。但是,当我使用TensorFlow Estimator时,我只能传递优化器本身,然后在Estimator的内部循环中使用它来最小化损失。
我能想到的唯一解决方案是定义一个自定义优化器,并覆盖Optimizer.minimize
方法。类似这样:
def minimize(self, *args, **kwargs): print("Inside...") if not kwargs['var_list']: kwargs['var_list'] = self.var_list return super(MyOptimizer, self).minimize(*args, **kwargs)
现在,我期望在每个训练步骤中看到“Inside…”这句话被打印出来;特别是当我看到模型正常训练时。这似乎表明我的minimize
函数被完全忽略了,我似乎无法弄清楚为什么会这样。
那么,覆盖minimize
是否正确,或者使用Estimator有更好的方法来做到这一点?
回答:
您可以简单地通过指定model_fn
函数来创建自定义的Estimator
def model_fn(features, labels, mode): logits = model_architecture(features) loss = loss_function(logits, labels) if mode == tf.estimator.ModeKeys.TRAIN: optimizer = optimizer train_op = optimizer.minimize(loss=loss, global_step=global_step, var_list=variables_to_minimize) return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)