如何在Tensorflow训练过程中打印梯度?

为了调试Tensorflow模型,我需要查看梯度是否发生变化,或者其中是否存在NaN值。简单地在Tensorflow中打印变量是行不通的,因为你看到的只是:

 <tf.Variable 'Model/embedding:0' shape=(8182, 100) dtype=float32_ref>

我尝试使用tf.Print类,但无法使其工作,我怀疑它是否真的可以这样使用。在我的模型中,我有一个训练循环,它打印每个epoch的损失值:

def run_epoch(session, model, eval_op=None, verbose=False):    costs = 0.0    iters = 0    state = session.run(model.initial_state)    fetches = {            "cost": model.cost,            "final_state": model.final_state,    }    if eval_op is not None:        fetches["eval_op"] = eval_op    for step in range(model.input.epoch_size):        feed_dict = {}        for i, (c, h) in enumerate(model.initial_state):            feed_dict[c] = state[i].c            feed_dict[h] = state[i].h        vals = session.run(fetches, feed_dict)        cost = vals["cost"]        state = vals["final_state"]        costs += cost        iters += model.input.num_steps    print("Loss:", costs)    return costs

print(model.gradients[0][1])插入此函数中不起作用,因此我在打印损失值后尝试使用以下代码:

grads = model.gradients[0][1]x = tf.Print(grads, [grads])session.run(x)

但我得到了以下错误消息:

ValueError: Fetch argument <tf.Tensor 'mul:0' shape=(8182, 100) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("mul:0", shape=(8182, 100), dtype=float32) is not an element of this graph.)

这有道理,因为tf.Print确实不是图的一部分。所以,我尝试在实际图中的损失计算后使用tf.Print,但这也没有起作用,我仍然得到Tensor("Train/Model/mul:0", shape=(8182, 100), dtype=float32)

如何在Tensorflow的训练循环内部打印梯度变量?


回答:

根据我的经验,在tensorflow中查看梯度流的最佳方法不是使用tf.Print,而是使用tensorboard。以下是我在另一个问题中使用的示例代码,当时梯度是学习中的关键问题:

for g, v in grads_and_vars:  tf.summary.histogram(v.name, v)  tf.summary.histogram(v.name + '_grad', g)merged = tf.summary.merge_all()writer = tf.summary.FileWriter('train_log_layer', tf.get_default_graph())..._, summary = sess.run([train_op, merged], feed_dict={I: 2*np.random.rand(1, 1)-1})if i % 10 == 0:  writer.add_summary(summary, global_step=i)

这将向你展示梯度随时间的分布。顺便说一句,要检查NaN值,tensorflow中有一个专门的函数:tf.is_nan。通常,你不需要检查梯度是否为NaN:当这种情况发生时,变量也会爆炸,这在tensorboard上会非常明显。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注