PyTorch: "KeyError: Caught KeyError in DataLoader worker process 0."

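Problem description

I am trying to load image data with a custom PyTorch dataset, but I get the error message below. After the error occurred I inspected the data and found that my image set contains two shapes, (512, 512, 3) and (1024, 1024). My assumption is that the error is related to this.

Note: the code reads some of the images successfully but throws the error for others.

Questions

  1. How should image data like this be preprocessed for training?

  2. Is there any other cause for the error message?

Error message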

KeyError                                  Traceback (most recent call last)
<ipython-input-163-aa3385de8026> in <module>
----> 1 train_features, train_labels = next(iter(train_dataloader))
      2 print(f"Feature batch shape: {train_features.size()}")
      3 print(f"Labels batch shape: {train_labels.size()}")
      4 img = train_features[0].squeeze()
      5 label = train_labels[0]

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
    521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1201             else:
   1202                 del self._task_info[idx]
   1203                 return self._process_data(data)
   1204
   1205     def _try_put_index(self):

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1227         self._try_put_index()
   1228         if isinstance(data, ExceptionWrapper):
   1229             data.reraise()
   1230         return data
   1231

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/_utils.py in reraise(self)
    423             # have message field
    424             raise self.exc_type(message=msg)
    425         raise self.exc_type(msg)
    426
    427

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 2898, in get_loc
    return self._engine.get_loc(casted_key)
  File "pandas/_libs/index.pyx", line 70, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/index.pyx", line 101, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 1032, in pandas._libs.hashtable.Int64HashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1039, in pandas._libs.hashtable.Int64HashTable.get_item
KeyError: 16481

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-161-f38b78d77dcb>", line 19, in __getitem__
    img_path =os.path.join(self.img_dir,self.image_ids[idx])
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/series.py", line 882, in __getitem__
    return self._get_value(key)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/series.py", line 990, in _get_value
    loc = self.index.get_loc(label)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pandas/core/indexes/base.py", line 2900, in get_loc
    raise KeyError(key) from err
KeyError: 16481

Code

import os

from torchvision.io import read_image
import torch
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset


class CustomImageDataset(Dataset):
    def __init__(self, dataset, transforms=None, target_transforms=None):
        # self.train_data = pd.read_csv("Data/train_data.csv")
        self.image_ids = dataset.image_id
        self.image_labels = dataset.label
        self.img_dir = 'Data/images'
        self.transforms = transforms
        self.target_transforms = target_transforms

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        # image path
        img_path = os.path.join(self.img_dir, self.image_ids[idx])
        # image
        image = read_image(img_path)
        label = self.image_labels[idx]
        # transform image
        if self.transforms:
            image = self.transforms(image)
        # transform target
        if self.target_transforms:
            label = self.target_transforms(label)
        return image, label

train_data is a pandas object loaded from a CSV file that contains the image IDs and label information.

from sklearn.model_selection import train_test_split

X_train, X_test = train_test_split(train_data, test_size=0.1, random_state=42)
train_df = CustomImageDataset(X_train)
train_dataloader = torch.utils.data.DataLoader(
    train_df,
    batch_size=64,
    num_workers=1,
    shuffle=True,
)

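Answer:

I found the problem in the code.

The __getitem__ method of a custom PyTorch dataset uses idx to retrieve a sample. My guess is that the range of idx is derived from __len__, i.e. it runs from 0 to the number of rows in the dataset minus one.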
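That guess can be checked directly. The snippet below is a minimal sketch, assuming a hypothetical TinyDataset with five items (not part of the original code). It lists the indices produced by the two samplers that a DataLoader uses for a map-style dataset when shuffle is False or True.

```python
from torch.utils.data import Dataset, RandomSampler, SequentialSampler


class TinyDataset(Dataset):
    """Hypothetical 5-item dataset, used only to inspect the indices."""

    def __len__(self):
        return 5

    def __getitem__(self, idx):
        return idx


ds = TinyDataset()

# Sampler used when shuffle=False: positions 0..len(ds)-1 in order.
sequential = list(SequentialSampler(ds))

# Sampler used when shuffle=True: the same positions in random order.
shuffled = list(RandomSampler(ds))

print(sequential)
print(sorted(shuffled))
```

In both cases the indices are plain positions between 0 and len(ds) - 1. They know nothing about the labels stored in a pandas index.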

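In my case, I already had a pandas dataset (train_data) that carried its own idx values, the pandas row index. When I randomly split it into X_train and X_test, some rows were moved into X_test together with their idx values.

When I then passed X_train to the custom dataset, it tried to use idx to look up the image_id of a row, and that idx happened to belong to X_test. That caused the error KeyError: 16481: the row with idx=16481 is not in X_train, because the split moved it into X_test.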
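The failure and two possible fixes can be reproduced with pandas alone. This is a minimal sketch, assuming a hypothetical 10-row stand-in for train_data and a hand-picked subset in place of train_test_split. Neither fix is taken from the original post; they are two common ways to make the lookup positional.

```python
import pandas as pd

# Hypothetical stand-in for train_data: 10 rows with the default 0..9 index.
train_data = pd.DataFrame({
    "image_id": [f"img_{i}.png" for i in range(10)],
    "label": [i % 2 for i in range(10)],
})

# Simulate the split: dropping rows keeps the original index labels,
# just as train_test_split does for the rows it leaves in X_train.
X_train = train_data.drop(index=[1, 4, 7])

image_ids = X_train.image_id
n = len(image_ids)  # 7 rows, so the sampler would yield positions 0..6

# Label-based lookup, as in the original __getitem__: label 1 is gone.
try:
    image_ids[1]
    lookup_failed = False
except KeyError:
    lookup_failed = True

# Fix A: look rows up by position with .iloc inside __getitem__.
by_position = [image_ids.iloc[i] for i in range(n)]

# Fix B: rebuild a clean 0..n-1 index right after the split.
X_train_reset = X_train.reset_index(drop=True)
by_label = [X_train_reset.image_id[i] for i in range(n)]

print(lookup_failed)
print(by_position == by_label)
```

Applied to the code in the question, fix A would mean using self.image_ids.iloc[idx] and self.image_labels.iloc[idx] in __getitem__, and fix B would mean calling reset_index(drop=True) on X_train before passing it to CustomImageDataset.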

Phew…
