正确设置数据加载器以加速训练PyTorch中的fasterrcnn-resnet50进行对象检测

我正在尝试训练PyTorch的torchvision.models.detection.fasterrcnn_resnet50_fpn模型,以便在我的图像中检测对象。

根据文档,该模型需要一个图像列表和一个包含’boxes’和’labels’作为键的字典列表。因此,我的__getitem__()方法如下所示:

def __getitem__(self, idx):
    # 加载图像
    _, img = self.images[idx].getImage()
    img = Image.fromarray(img, mode='RGB')
    objects = self.images[idx].objects
    boxes = []
    labels = []
    for o in objects:
        # 将边界框添加到boxes中
        boxes.append([o.x, o.y, o.x+o.width, o.y+o.height])
        # 添加class_id的第四个字符,即灯光数量(1-4)
        labels.append(int(str(o.class_id)[3]))
    # 将所有内容转换为torch.Tensor
    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    labels = torch.as_tensor(labels, dtype=torch.int64)
    target = {}
    target["boxes"] = boxes
    target["labels"] = labels
    # 目前,transforms仅包含transforms.Compose([transforms.ToTensor()])
    if self.transforms is not None:
        img = self.transforms(img)
    return img, target

据我所知,它返回的正是所需的内容。我的数据加载器如下所示:

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=4, shuffle=False, num_workers=2)

然而,当执行到这一阶段时:for images, targets in dataloaders[phase]:,它会引发以下错误:

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 7 in dimension 1 at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensor.cpp:689

有人能指点我正确的方向吗?


回答:

@[隐藏人名] 是对的,我需要编写一个单独的collate函数,以便网络能够按预期接收数据。在我的情况下,我只需要绕过默认函数即可。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注