我在比较模型预测结果与训练集标签时遇到了问题。我使用的数组形状如下:
训练集 (200000, 28, 28) (200000,)
验证集 (10000, 28, 28) (10000,)
测试集 (10000, 28, 28) (10000,)
然而,在使用以下函数检查准确率时:
def accuracy(predictions, labels): return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1)) / predictions.shape[0])
它给出的结果是:
C:\Users\***\Anaconda3\lib\site-packages\ipykernel_launcher.py:5: DeprecationWarning: elementwise == comparison failed; this will raise an error in the future.”””
并且所有数据集的准确率都显示为0%。
我认为我们不能使用’==’来比较数组。那么应该如何正确地比较这些数组呢?
回答:
我认为错误发生在这个表达式中:
np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1))
您能告诉我们关于predictions
和labels
这两个数组的一些信息吗?通常的信息 – 数据类型、形状、一些样本值。也许可以进一步展示每个数组的np.argmax(...)
结果。
在numpy
中,您可以比较大小相同的数组,但对于大小不匹配的数组的比较变得更加严格:
In [522]: np.arange(10)==np.arange(5,15)Out[522]: array([False, False, False, False, False, False, False, False, False, False], dtype=bool)In [523]: np.arange(10)==np.arange(5,14)/usr/local/bin/ipython3:1: DeprecationWarning: elementwise == comparison failed; this will raise an error in the future. #!/usr/bin/python3Out[523]: False