我在为 DCNN 模型编写分类报告时遇到了一个错误。我的代码是
from sklearn.metrics import confusion_matrixtest = ImageDataGenerator()test_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)test_data = test_generator.flow_from_directory(directory="/content/dataset/test",target_size=IMAGE_SHAPE , color_mode="rgb" , class_mode='categorical' , batch_size=1 , shuffle = False )test_data.reset()predicted_class_indices=np.argmax(pred,axis=1)cm = confusion_matrix(test_labels, predictions.argmax(axis=1))
错误:
AttributeError: 'list' 对象没有属性 'argmax'
回答:
显然,您的 predictions
是一个 Python 列表,而列表没有 argmax
属性;您需要使用 Numpy 函数 argmax()
:
predictions = [[0.1, 0.9], [0.8, 0.2]] # 示例数据y_pred_binary = predictions.argmax(axis=1)# AttributeError: 'list' 对象没有属性 'argmax'# 使用 Numpy:import numpy as npy_pred_binary = np.argmax(predictions, axis=1)y_pred_binary# array([1, 0])