我在 TensorFlow 中训练了一个模型,当我从训练好的模型中进行评估时,它运行得非常完美。
然而,在不同阶段,我会保存一个检查点,然后加载这个检查点来进行评估。加载后的网络只会输出 NaN 值。
使用 tfdbg 并在输入时运行“has_inf_or_nan”过滤器,最终显示网络中首次出现的 NaN 值出现在一个批量归一化层的 moving_mean 和 moving_variance 变量中。
保存操作使用以下代码进行:
with self.graph.as_default(): if not self.saver: self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=10000) save_dir = create_save_dir(path, name) return self.saver.save(self.session, save_dir, global_step=iteration, write_meta_graph=True)
加载操作使用以下代码进行:
with self.graph.as_default(): save_dir = create_save_dir(load_dir, load_name) self.saver = tf.train.import_meta_graph(save_dir + "-" + str(iteration) + ".meta") self.saver.restore(self.session, save_dir + "-" + str(iteration)) self.input_layer = self.graph.get_tensor_by_name("network/input_layer:0") self.out_policy_layer = self.graph.get_tensor_by_name("network/out_policy_layer:0") self.out_value_layer = self.graph.get_tensor_by_name("network/out_value_layer/Tanh:0") self.is_training = self.graph.get_tensor_by_name("network/is_training:0")
再次强调,让我怀疑保存/加载程序存在问题的关键是,如果我通过已经训练好的网络运行,网络会输出有效结果。只有当我通过加载的网络运行时,才会得到 NaN 值。
编辑补充,我的批量归一化是使用以下代码创建的:
def _conv_block(self, name, input_layer, filter_size, num_input_channels, num_output_channels): weights = self._create_weights_for_layer(f"{name}_weights", shape=[filter_size[0], filter_size[1], num_input_channels, num_output_channels], use_regularizer=self._config.l2_regularizer) conv = self._conv2d(input_layer, weights, strides=[1, 1, 1, 1], padding="SAME", name=f"{name}_conv") bn = self._conv_batch_norm(conv, f"{name}_batch_norm") return tf.nn.relu(bn, name=f"{name}_act")def _conv_batch_norm(self, input_layer, name): return tf.layers.batch_normalization(input_layer, axis=CHANNEL_SHAPE_INDEX, center=True, scale=True, training=self.is_training, momentum=self._config.batch_norm_momentum, name=name)
回答:
简而言之,如果你的模型中随机出现 NaN 值,并且你已经检查了常见的原因,请考虑你的硬件可能出现故障,不要浪费数百个小时。
这是由一块内存即将损坏的显卡引起的。我以为这是软件问题,并没有想到这一点。训练是在另一台电脑上进行的,没有遇到任何问题。
我们最终意识到我们有硬件问题的情况是这样的。我们之前在旧电脑上经历过随机出现的 NaN 值。我们花了数百个小时调试模型,以为是模型的问题。在我们做出更改后,我们也恰好换到了一台升级的电脑,所以我们认为我们的更改解决了问题,因为 NaN 值停止了。然后一个月后我们开始使用那台旧电脑进行评估,又遇到了 NaN 值。那时我发了这个帖子。之后不久我意识到可能存在硬件问题。