我的CNN模型通过model.predict()
输出了以下结果:
Tensor("input_1:0", shape=(?, 2, 26, 1), dtype=float32)[9.9952221e-01 2.3613637e-04 1.9953270e-06 1.6922619e-05 2.2012556e-04 2.4441533e-07 3.5276526e-07 7.4913805e-07 4.0657511e-07 8.7760031e-07]
我想从这个NumPy数组中获取最大值的索引。我目前尝试使用以下代码(x
是上面的数组):
result = x.index(max(x))
然而,这会引发一个错误,称这种数据类型不支持.index
方法?
回答:
你可以简单地使用np.argmax
函数:
import numpy as nppreds = model.predict(test_data)pred_class = np.argmax(preds, axis=-1)