我正在从一个数据集(使用flow_from_dataframe创建)中迭代图像,并希望使用pyplot显示图像及其对应的标签
training_dataset = train_dataset_gen.flow_from_dataframe(dataframe=train_df, directory="./Data/train/", x_col="file", y_col="label", shuffle=True, batch_size=BATCH_SIZE, target_size=IMG_SIZE, class_mode="sparse", subset="training")
我能够显示图像,但标签显示为0和1的数组。
for _ in range(len(training_dataset.filenames)): image, label = training_dataset.next() # 从迭代器中显示图像 plt.imshow(image[0]) plt.title(training_dataset.) plt.show()
如何获取真实的标签值?
这个问题通过反转字典得到了解决:
class_names=training_dataset.class_indicesnew_dict={}for key, value in class_names.items(): new_dict[value]=key
回答:
ImageDataGenerator有一个属性.class_indices。它是一个字典,形式如下
{'class name': class index}
其中类名是字符串,表示您数据框标签列中的标签名称。类索引是与类名关联的整数。这是您在打印标签时看到的值(因为您使用了class_mode=’sparse’)。例如,在您的DataFrame中,如果您的标签是cat和dog,那么您的class_indices将是
{'cat':0, 'dog':1}
反转这个字典非常方便,如下所示
for key, value in class_dict.items(): new_dict[value]=key
现在new_dict的形式为{0:’cat’, 1:’dog’}。因此,您可以按以下方式修改代码
for _ in range(len(training_dataset.filenames)): image, label = training_dataset.next() # 从迭代器中显示图像 plt.imshow(image[0]) label_name=new_dict[label[0]] # 注意您只显示批次中的第一张图像 plt.title(label_name) plt.show()