我遇到一个问题,无论我传入什么标签/预测值,TF.Metrics.Mean_Squared_Error总是返回0值。
以下是复现问题的代码:
a = tf.constant([0,0,0,0])b = tf.constant([1,1,1,1])mse, update = tf.metrics.mean_squared_error(a,b)sess = tf.Session()sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())mse.eval(session=sess)
%% 返回0.0
回答:
我不知道为什么会这样,但实际上你需要先运行update
,mse的内部状态才会考虑你的数据:
a = tf.constant([0,0,0,0])b = tf.constant([1,1,1,1])mse, update = tf.metrics.mean_squared_error(a,b)sess = tf.Session()sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())mse.eval(session=sess) # 返回0.0,初始的MSE值update.eval(session=sess) # 返回1.0,更新的mse值update.eval(session=sess) # 返回1.0,这是0.0+1.0,更新后的mse值
tf.metrics.mean_squared_error()
旨在计算整个数据集的MSE,例如,如果你想要独立批次的结果,你不应该使用它。对于这种情况,你可以使用tf.losses.mean_squared_error(a, b, loss_collection=None)
。