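I'm trying to create an image classifier using TensorFlow and Python. However, I'm getting a strange "index out of range" error. The program should read the files, read the first three letters of each file name, and train based on whether the image is a cat or a dog.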
import cv2
import numpy as np
import os
from random import shuffle
from tqdm import tqdm

TRAIN_DIR = 'C:\\Users\\cward\\Desktop\\images\\train'
TEST_DIR = 'C:\\Users\\cward\\Desktop\\images\\test'
IMG_SIZE = 50
LR = 1e-3
MODEL_NAME = 'dogsvscats-{}-{}.model'.format(LR, '2conv-basic')

def label_img(img):
    word_label = img.split('.')[-2]
    if word_label == 'cat': return[1,0]
    elif word_label == 'dog': return[0,1]

def create_train_data():
    training_data = []
    for img in tqdm(os.listdir(TRAIN_DIR)):
        label = label_img(img)
        path = os.path.join(TRAIN_DIR, img)
        img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), (IMG_SIZE,IMG_SIZE))
        training_data.append([np.array(img), np.array(label)])
    shuffle(traning_data)
    np.save('train_data.npy', traning_data)
    return training__data

def process_test_data():
    testing_data = []
    for img in tqdm(os.listdir(TRAIN_DIR)):
        path = os.path.join(TRAIN_DIR, img)
        img_num = img.split('.')[0]
        img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), (IMG_SIZE,IMG_SIZE))
        testing_data.append([np.array(img), img_num])
    np.save('test_data.npy',testing_data)
    return testing_data

train_data = create_train_data()
This is the error message:
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-34-40719067ea74> in <module>()
----> 1 train_data = create_train_data()

<ipython-input-32-88b70eb23645> in create_train_data()
      2     training_data = []
      3     for img in tqdm(os.listdir(TRAIN_DIR)):
----> 4         label = label_img(img)
      5         path = os.path.join(TRAIN_DIR, img)
      6         img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), (IMG_SIZE,IMG_SIZE))

<ipython-input-31-82bc72a4ed99> in label_img(img)
      1 def label_img(img):
----> 2     word_label = img.split('.')[-2]
      3     if word_label == 'cat': return[1,0]
      4     elif word_label == 'dog': return[0,1]

IndexError: list index out of range
I'm new to Python, so please forgive my bad formatting!
Answer:
The error message shows that img.split('.') returned fewer than 2 elements, so the index [-2] does not exist.
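To make that concrete, here is a small sketch of how `split('.')` behaves on a few names. The file names are made up for illustration and are not taken from the asker's directory; `can_label` is a hypothetical helper, not part of the original code.

```python
# Hypothetical names, for illustration only.
names = ["cat.0.jpg", "dog.12.jpg", "Thumbs", "subfolder"]

def can_label(name):
    # label_img indexes [-2], which needs at least two dot-separated parts.
    return len(name.split('.')) >= 2

results = {name: can_label(name) for name in names}
for name, ok in results.items():
    print(name, name.split('.'), "ok" if ok else "IndexError on [-2]")
```

Any name without a dot splits into a single-element list, and indexing that list with [-2] raises the IndexError from the traceback.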
Do you have subdirectories in TRAIN_DIR? That would trigger this error. Personally, I'd suggest you first try this inside the loop:
try:
    label = label_img(img)
except IndexError:
    print(img)
    continue
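This should print every img value that triggers the error. Most likely some image file is missing its extension. Once you have identified the problem and fixed any offending files, you can do this: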
if len(img.split('.')) < 2:
    continue
label = label_img(img)
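This makes the code skip any file that would trigger the error. That way your code still works even if you do have subdirectories (although images inside those subdirectories will still be ignored).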
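Putting the two checks together, a minimal sketch might look like the following. `safe_label` and `labelled_files` are hypothetical helper names, not part of the question's code, and the `os.path.isfile` test is an extra assumption on my part for skipping subdirectories explicitly. The [-2] index is kept exactly as in the question.

```python
import os

def safe_label(filename):
    """Return [1, 0] for cat, [0, 1] for dog, or None if the name cannot be labelled."""
    parts = filename.split('.')
    if len(parts) < 2:
        return None  # no dot in the name, so [-2] would raise IndexError
    word_label = parts[-2]  # same index as the question's label_img
    if word_label == 'cat':
        return [1, 0]
    if word_label == 'dog':
        return [0, 1]
    return None  # name has a dot but is neither cat nor dog

def labelled_files(directory):
    """Yield (path, label) for regular files whose names can be labelled."""
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if not os.path.isfile(path):
            continue  # skip subdirectories
        label = safe_label(name)
        if label is None:
            print("skipping", name)
            continue
        yield path, label
```

One thing to watch: with a name such as cat.0.jpg, the [-2] part is '0' rather than 'cat', so this sketch would skip it. If your files are named that way, the index probably needs to change to match your naming scheme.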