Problem description
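I tried to load image data with a PyTorch custom dataset, but got the error message below. After the error, I checked the data and found that my image set contains images of two sizes: (512, 512, 3) and (1024, 1024). My assumption is that the error is related to this.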
Note: the code is able to read some images, but throws the error for others.
Questions
- How should I preprocess image data like this for training?
- Is there some other cause of the error message?
Error message
KeyError                                  Traceback (most recent call last)
<ipython-input-163-aa3385de8026> in <module>
----> 1 train_features, train_labels = next(iter(train_dataloader))
      2 print(f"Feature batch shape: {train_features.size()}")
      3 print(f"Labels batch shape: {train_labels.size()}")
      4 img = train_features[0].squeeze()
      5 label = train_labels[0]

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
    521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1201             else:
   1202                 del self._task_info[idx]
   1203                 return self._process_data(data)
   1204 
   1205     def _try_put_index(self):

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1227         self._try_put_index()
   1228         if isinstance(data, ExceptionWrapper):
   1229             data.reraise()
   1230         return data
   1231 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_utils.py in reraise(self)
    423             # have message field
    424             raise self.exc_type(message=msg)
    425         raise self.exc_type(msg)
    426 
    427 

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 2898, in get_loc
    return self._engine.get_loc(casted_key)
  File "pandas/_libs/index.pyx", line 70, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/index.pyx", line 101, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 1032, in pandas._libs.hashtable.Int64HashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1039, in pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 16481

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-161-f38b78d77dcb>", line 19, in __getitem__
    img_path =os.path.join(self.img_dir,self.image_ids[idx])
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/series.py", line 882, in __getitem__
    return self._get_value(key)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/series.py", line 990, in _get_value
    loc = self.index.get_loc(label)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 2900, in get_loc
    raise KeyError(key) from err
KeyError: 16481
Code
import os  # needed for os.path.join in __getitem__

from torchvision.io import read_image
import torch
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset


class CustomImageDataset(Dataset):
    def __init__(self, dataset, transforms=None, target_transforms=None):
        # self.train_data = pd.read_csv("Data/train_data.csv")
        self.image_ids = dataset.image_id
        self.image_labels = dataset.label
        self.img_dir = 'Data/images'
        self.transforms = transforms
        self.target_transforms = target_transforms

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        # image path; on a pandas Series with an integer index, [idx] looks up
        # the index label idx, not the idx-th row
        img_path = os.path.join(self.img_dir, self.image_ids[idx])
        # image
        image = read_image(img_path)
        label = self.image_labels[idx]
        # transform image
        if self.transforms:
            image = self.transforms(image)
        # transform target
        if self.target_transforms:
            label = self.target_transforms(label)
        return image, label
`train_data` is a pandas DataFrame read from a CSV file that contains the image IDs and their labels.
from sklearn.model_selection import train_test_split

X_train, X_test = train_test_split(train_data, test_size=0.1, random_state=42)
train_df = CustomImageDataset(X_train)
train_dataloader = torch.utils.data.DataLoader(
    train_df, batch_size=64, num_workers=1, shuffle=True)
Answer:
I found the problem in the code.
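In a PyTorch custom dataset, `__getitem__` receives an `idx` and uses it to retrieve one sample. The DataLoader's sampler derives the range of `idx` from `__len__`: it produces integers from 0 to the number of rows in the dataset minus 1.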
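This can be checked on the samplers directly: for a map-style dataset, the DataLoader uses a `RandomSampler` when `shuffle=True` and a `SequentialSampler` otherwise, and both yield only positions `0..len(dataset) - 1`. A small sketch in which a plain 10-item list stands in for the dataset (an assumption for illustration; the samplers only need `len()`):

```python
from torch.utils.data import RandomSampler, SequentialSampler

# A plain list stands in for a dataset with 10 rows; the values are irrelevant,
# the samplers only look at len(data_source).
data_source = ["row"] * 10

sequential = list(SequentialSampler(data_source))
shuffled = list(RandomSampler(data_source))

print(sequential)        # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(sorted(shuffled))  # same positions as above, only the order differs
```

Neither sampler knows anything about the DataFrame's own index labels; they only count rows.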
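In my case, the pandas DataFrame (`train_data`) already carried its own row index (the idx of each row). When I randomly split it into `X_train` and `X_test` with `train_test_split`, the rows moved into `X_test` took their index labels with them, so the index of `X_train` is no longer a continuous 0 to N-1.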
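The effect of the split on the index can be reproduced without any images. A minimal sketch, assuming a hypothetical 10-row stand-in for `train_data` with the same `image_id` and `label` columns and the default 0..9 index:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for train_data: 10 rows, default index 0..9.
train_data = pd.DataFrame({
    "image_id": [f"img_{i}.png" for i in range(10)],
    "label": [i % 2 for i in range(10)],
})

X_train, X_test = train_test_split(train_data, test_size=0.2, random_state=42)

# The split keeps the original index labels, so X_train's index has gaps.
print("X_train index:", list(X_train.index))
print("X_test index: ", list(X_test.index))

# The DataLoader asks for idx in range(len(X_train)), but X_train.image_id[idx]
# is a label lookup, so any idx whose row went to X_test raises KeyError.
missing = [idx for idx in range(len(X_train)) if idx not in X_train.index]
errors = []
for idx in missing:
    try:
        X_train.image_id[idx]  # same lookup as self.image_ids[idx] in __getitem__
    except KeyError as err:
        errors.append(err)
print(f"{len(errors)} of {len(X_train)} requested positions raise KeyError")
```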
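When I then passed `X_train` to the custom dataset, `self.image_ids[idx]` looked up the `image_id` by that index label, and some of the requested idx values belonged to rows that had ended up in `X_test`. That is what caused `KeyError: 16481`: the row with index 16481 is not in `X_train`, because it was moved to `X_test` during the split. It also explains why some images loaded and others failed: only the idx values whose rows stayed in `X_train` can be found.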
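Two common ways to make `idx` line up with the rows again are to renumber the split frame with `reset_index(drop=True)`, or to look rows up by position with `.iloc` instead of by label. A minimal sketch of both, again on a hypothetical 10-row stand-in for `train_data`:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for train_data (same columns as in the question).
train_data = pd.DataFrame({
    "image_id": [f"img_{i}.png" for i in range(10)],
    "label": [i % 2 for i in range(10)],
})
X_train, X_test = train_test_split(train_data, test_size=0.2, random_state=42)

# Option 1: renumber the rows 0..len-1 after the split, so label == position.
X_train_reset = X_train.reset_index(drop=True)
ids_by_label = [X_train_reset.image_id[idx] for idx in range(len(X_train_reset))]

# Option 2: keep the original index but select by position with .iloc,
# as self.image_ids.iloc[idx] would inside __getitem__.
ids_by_position = [X_train.image_id.iloc[idx] for idx in range(len(X_train))]

print(ids_by_label == ids_by_position)  # True: both walk the rows in the same order
```

Applied to the question's code, that would be either `CustomImageDataset(X_train.reset_index(drop=True))`, or changing `__getitem__` to use `self.image_ids.iloc[idx]` and `self.image_labels.iloc[idx]`.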
Phew…