我认为在以下代码的第53行存在一个错误:
代码显示:
return_images = torch.cat(return_images, 0) # 收集所有图像并返回
你觉得应该如何正确地修改这一行?在浏览了代码之后,我遗憾地无法确定这一行的具体功能,但我认为我理解了直到第52行的其他部分。
回答:
torch.cat
的第一个参数应该是一个张量序列,而不是单个张量。因此,你应该这样传递参数:
torch.cat([tensor_1, tensor_2, tensor_3]) # 正确的方法
而不是这样:
torch.cat(tensor_1, tensor_2, tensor_3) # 不正确的方法
在你链接的代码中,他们正在形成一个名为return_images
的列表,其中包含了多个张量。
np.concatenate
也有类似的行为,PyTorch的设计者可能从中借鉴了这个选择。(“cat”是“concatenate”的缩写,后者较难拼写!)