我正在尝试训练PyTorch的torchvision.models.detection.fasterrcnn_resnet50_fpn
模型,以便在我的图像中检测对象。
根据文档,该模型需要一个图像列表和一个包含’boxes’和’labels’作为键的字典列表。因此,我的__getitem__()
方法如下所示:
def __getitem__(self, idx):
# 加载图像
_, img = self.images[idx].getImage()
img = Image.fromarray(img, mode='RGB')
objects = self.images[idx].objects
boxes = []
labels = []
for o in objects:
# 将边界框添加到boxes中
boxes.append([o.x, o.y, o.x+o.width, o.y+o.height])
# 添加class_id的第四个字符,即灯光数量(1-4)
labels.append(int(str(o.class_id)[3]))
# 将所有内容转换为torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = labels
# 目前,transforms仅包含transforms.Compose([transforms.ToTensor()])
if self.transforms is not None:
img = self.transforms(img)
return img, target
据我所知,它返回的正是所需的内容。我的数据加载器如下所示:
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=4, shuffle=False, num_workers=2)
然而,当执行到这一阶段时:for images, targets in dataloaders[phase]:
,它会引发以下错误:
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 7 in dimension 1 at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensor.cpp:689
有人能指点我正确的方向吗?
回答:
@[隐藏人名] 是对的,我需要编写一个单独的collate函数,以便网络能够按预期接收数据。在我的情况下,我只需要绕过默认函数即可。