TensorFlow中pytorch的autograd.detect_anomaly等效功能

我在尝试调试我的TensorFlow代码,该代码在大约30个epoch后突然产生了NaN损失。你可以在这个Stack Overflow问题中找到我具体的问题和尝试过的方法。

我在训练过程中监控了每个mini-batch中所有层的权重,发现尽管前一次迭代中的所有权重值都小于1(我已经将kernel_constraint的max_norm设置为1),权重突然跳转到NaN。这使得找出哪个操作是罪魁祸首变得非常困难。

Pytorch有一个很酷的调试方法torch.autograd.detect_anomaly,它会在产生NaN值的任何反向计算时产生错误并显示调用堆栈。这使得调试代码变得很容易。

TensorFlow中有类似的功能吗?如果没有,你能建议一种调试方法吗?


回答:

TensorFlow中确实有一个类似的调试工具。请参见tf.debugging.check_numerics

这可以用来追踪训练过程中产生infnan值的张量。一旦发现这样的值,TensorFlow会产生一个InvalidArgumentError

tf.debugging.check_numerics(LayerN, "LayerN is producing nans!")

如果张量LayerN包含NaN,你会得到如下错误:

Traceback (most recent call last):  File "trainer.py", line 506, in <module>    worker.train_model()  File "trainer.py", line 211, in train_model    l, tmae = train_step(*batch)  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 828, in __call__    result = self._call(*args, **kwds)  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 855, in _call    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2943, in __call__    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat    ctx, args, cancellation_manager=cancellation_manager))  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 560, in call    ctx=ctx)  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute    inputs, attrs, num_outputs)tensorflow.python.framework.errors_impl.InvalidArgumentError:  LayerN is producing nans! : Tensor had NaN values

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注