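I normally do my letter and digit recognition work on my own computer. Now I want to move the project to Colab, but unfortunately I run into an error (shown below). After some debugging I tracked it down to this line: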
transforms.RandomRotation(degrees=(90, -90))
Below is a small, minimal example that reproduces the error. It fails on Colab but runs fine in my local environment. The cause may be the different PyTorch versions: my computer has 1.3.1, while Colab uses 1.4.0.
import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

transformOpt = transforms.Compose([
    transforms.RandomRotation(degrees=(90, -90)),
    transforms.ToTensor()
])

train_set = datasets.MNIST(
    root='', train=True, transform=transformOpt, download=True)
test_set = datasets.MNIST(
    root='', train=False, transform=transformOpt, download=True)

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=100,
    shuffle=True)

test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=100,
    shuffle=False)

images, labels = next(iter(train_loader))
plt.imshow(images[0].view(28, 28), cmap="gray")
plt.show()
This is the full error message I get when running the example above on Google Colab:
TypeError                                 Traceback (most recent call last)
<ipython-input-1-8409db422154> in <module>()
     24     shuffle=False)
     25 
---> 26 images, labels = next(iter(train_loader))
     27 plt.imshow(images[0].view(28, 28), cmap="gray")
     28 plt.show()

10 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py in __getitem__(self, index)
     95 
     96         if self.transform is not None:
---> 97             img = self.transform(img)
     98 
     99         if self.target_transform is not None:

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
     68     def __call__(self, img):
     69         for t in self.transforms:
---> 70             img = t(img)
     71         return img
     72 

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
   1001         angle = self.get_params(self.degrees)
   1002 
-> 1003         return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
   1004 
   1005     def __repr__(self):

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in rotate(img, angle, resample, expand, center, fill)
    727         fill = tuple([fill] * 3)
    728 
--> 729     return img.rotate(angle, resample, expand, center, fillcolor=fill)
    730 
    731 

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in rotate(self, angle, resample, expand, center, translate, fillcolor)
   2003             w, h = nw, nh
   2004 
-> 2005         return self.transform((w, h), AFFINE, matrix, resample, fillcolor=fillcolor)
   2006 
   2007     def save(self, fp, format=None, **params):

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in transform(self, size, method, data, resample, fill, fillcolor)
   2297             raise ValueError("missing method data")
   2298 
-> 2299         im = new(self.mode, size, fillcolor)
   2300         if method == MESH:
   2301             # list of quads

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in new(mode, size, color)
   2503         im.palette = ImagePalette.ImagePalette()
   2504         color = im.palette.getcolor(color)
-> 2505     return im._new(core.fill(mode, size, color))
   2506 
   2507 

TypeError: function takes exactly 1 argument (3 given)
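The bottom frames of this traceback can be reproduced with Pillow alone. The sketch below assumes only that Pillow is installed and rotates a blank single-channel ("L" mode) image, the mode MNIST images come in, with three different fill colours. The exact TypeError message varies between Pillow versions, so the code only checks that one is raised.

```python
from PIL import Image

# A blank 28x28 single-channel image, same mode and size as an MNIST sample.
img = Image.new("L", (28, 28))

# A 3-tuple is what torchvision 0.5's rotate() passed as fillcolor;
# it is not a valid colour for a one-band image, so Pillow raises TypeError.
try:
    img.rotate(45, fillcolor=(0, 0, 0))
    three_tuple_failed = False
except TypeError as exc:
    three_tuple_failed = True
    print("3-tuple fill rejected:", exc)

# A plain int or a 1-tuple is accepted for a one-band image.
rotated_int = img.rotate(45, fillcolor=0)
rotated_tuple = img.rotate(45, fillcolor=(0,))
print(rotated_int.mode, rotated_tuple.size)
```

The 1-tuple case is the same value that fill=(0,) ends up passing to Pillow.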
Answer:
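You're exactly right. In torchvision 0.5 (the release that accompanies PyTorch 1.4.0), RandomRotation() has a bug in its fill argument, possibly caused by an incompatibility with the installed Pillow version. The last frames of the traceback show the mechanism: rotate() expands the default integer fill into a 3-tuple (fill = tuple([fill] * 3)), and Pillow rejects a 3-tuple as the fill colour of a single-channel image such as the mode "L" images MNIST returns. This has now been fixed (PR#1760), and the fix will ship in the next release.

For now, you can work around it by adding fill=(0,) to the RandomRotation transform: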
transforms.RandomRotation(degrees=(90, -90), fill=(0,))
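A fixed fill=(0,) only fits one-channel images. If the same pipeline might also see RGB images, one option is a small custom transform that builds one fill value per image band. The sketch below assumes only Pillow; RandomRotationMatchingFill is a hypothetical class, not part of torchvision, and because transforms.Compose simply calls each transform on the image, it could replace RandomRotation ahead of ToTensor(). The range is written as (min, max) = (-90, 90); the reversed (90, -90) from the question happened to work in torchvision 0.5 but may be rejected by later releases.

```python
import random

from PIL import Image


class RandomRotationMatchingFill:
    """Hypothetical transform: rotate a PIL image by a random angle in
    [min_deg, max_deg] and fill the uncovered corners with a colour that
    has exactly one value per band of the input image."""

    def __init__(self, degrees=(-90, 90), fill_value=0):
        self.min_deg, self.max_deg = degrees
        self.fill_value = fill_value

    def __call__(self, img):
        angle = random.uniform(self.min_deg, self.max_deg)
        # One entry per band: (0,) for "L", (0, 0, 0) for "RGB", and so on.
        fill = tuple([self.fill_value] * len(img.getbands()))
        return img.rotate(angle, fillcolor=fill)


rotate = RandomRotationMatchingFill(degrees=(-90, 90))
gray = rotate(Image.new("L", (28, 28)))
rgb = rotate(Image.new("RGB", (28, 28)))
print(gray.mode, gray.size, rgb.mode, rgb.size)  # L (28, 28) RGB (28, 28)
```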