因此,我尝试了这个教程(11:58分钟处),试图在我的CNN上实现,它包含数据集中10个物种。
我没有遇到加载数据时的错误
DATADIR = "dataset"CATEGORIES = ["Dendrobium_crumenatum","Grammatophyllum_speciosum", "Coelogyne_swaniana", "Bulbophyllum_purpurascens", "Agrostophyllum_stipulatum", "Spathoglottis_plicata", "Phalaenopsis_amabilis", "Nabaluia_angustifolia", "Habenaria_rhodocheila_hance"] 这里是输出示例[![enter image description here][2]][2]
然后是下一部分
training_data = []def create_training_data(): for category in CATEGORIES: path = os.path.join(DATADIR,category) class_num = CATEGORIES.index(category) for img in os.listdir(path): try: img_array = cv2.imread(os.path.join(path,img),cv2.IMREAD_GRAYSCALE) new_array = cv2.resize(img_array, (IMG_SIZE,IMG_SIZE)) training_data.append([new,array, class_num]) except Exception as e: passcreate_training_data()
但是当我打印prin(len(training_data))
我得到了这样的输出
0
当我尝试
import randomrandom.shuffle(training_data)for sample in training_data[:10]: print (sample[1])
它没有显示任何输出。这是否意味着我的训练数据为空?还是因为使用了类别的索引?因为我使用了10个类,而教程中使用了2个类。
回答:
使你的training_data成为全局变量
training_data =[]def create_training_data(): global training_data for category in CATEGORIES: path = os.path.join(DATADIR,category) class_num = CATEGORIES.index(category) for img in os.listdir(path): try: img_array = cv2.imread(os.path.join(path,img),cv2.IMREAD_GRAYSCALE) new_array = cv2.resize(img_array,(IMG_SIZE, IMG_SIZE)) training_data.append([new_array,class_num]) except Exception as e: passcreate_training_data()