Pytorch transforms.RandomRotation() 在 Google Colab 上无法工作

通常我在自己的电脑上进行字母和数字识别工作,现在我想将项目移到 Colab 上,但不幸的是遇到了错误(您可以在下面看到错误)。经过一些调试后,我找到了导致错误的那一行代码。

transforms.RandomRotation(degrees=(90, -90))

我在下面写了一个简单的抽象代码来展示这个错误。这个代码在 Colab 上无法工作,但在我的电脑环境下可以正常运行。问题可能是由于 PyTorch 库的不同版本造成的,我电脑上的版本是 1.3.1,而 Colab 使用的是 1.4.0 版本。

import torchimport torchvisionfrom torchvision import datasets, transformsimport matplotlib.pyplot as plt       transformOpt = transforms.Compose([            transforms.RandomRotation(degrees=(90, -90)),            transforms.ToTensor()        ])    train_set = datasets.MNIST(        root='', train=True, transform=transformOpt, download=True)    test_set = datasets.MNIST(        root='', train=False, transform=transformOpt, download=True)    train_loader = torch.utils.data.DataLoader(        dataset=train_set,        batch_size=100,        shuffle=True)    test_loader = torch.utils.data.DataLoader(        dataset=test_set,        batch_size=100,        shuffle=False)    images, labels = next(iter(train_loader))    plt.imshow(images[0].view(28, 28), cmap="gray")    plt.show()

这是我在 Google Colab 上执行上述示例代码时得到的完整错误信息。

TypeError                                 Traceback (most recent call last)<ipython-input-1-8409db422154> in <module>()     24     shuffle=False)     25 ---> 26 images, labels = next(iter(train_loader))     27 plt.imshow(images[0].view(28, 28), cmap="gray")     28 plt.show()10 frames/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)    343     344     def __next__(self):--> 345         data = self._next_data()    346         self._num_yielded += 1    347         if self._dataset_kind == _DatasetKind.Iterable and \/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)    383     def _next_data(self):    384         index = self._next_index()  # may raise StopIteration--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration    386         if self._pin_memory:    387             data = _utils.pin_memory.pin_memory(data)/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)     42     def fetch(self, possibly_batched_index):     43         if self.auto_collation:---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]     45         else:     46             data = self.dataset[possibly_batched_index]/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)     42     def fetch(self, possibly_batched_index):     43         if self.auto_collation:---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]     45         else:     46             data = self.dataset[possibly_batched_index]/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py in __getitem__(self, index)     95      96         if self.transform is not None:---> 97             img = self.transform(img)     98      99         if self.target_transform is not None:/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)     68     def __call__(self, img):     69         for t in self.transforms:---> 70             img = t(img)     71         return img     72 /usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)    1001         angle = self.get_params(self.degrees)    1002 -> 1003         return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)    1004     1005     def__repr__(self):/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in rotate(img, angle, resample, expand, center, fill)    727         fill = tuple([fill] * 3)    728 --> 729     return img.rotate(angle, resample, expand, center, fillcolor=fill)    730     731 /usr/local/lib/python3.6/dist-packages/PIL/Image.py in rotate(self, angle, resample, expand, center, translate, fillcolor)    2003         w, h = nw, nh    2004 -> 2005         return self.transform((w, h), AFFINE, matrix, resample, fillcolor=fillcolor)    2006     2007     def save(self,    fp, format=None, **params):/usr/local/lib/python3.6/dist-packages/PIL/Image.py in transform(self, size, method, data, resample, fill, fillcolor)    2297             raise ValueError("missing method data")    2298 -> 2299         im = new(self.mode, size, fillcolor)    2300         if method == MESH:    2301             # list of quads/usr/local/lib/python3.6/dist-packages/PIL/Image.py in new(mode, size, color)    2503         im.palette = ImagePalette.ImagePalette()    2504         color = im.palette.getcolor(color)-> 2505     return im._new(core.fill(mode, size, color))    2506     2507 TypeError: function takes exactly 1 argument (3 given)

回答:

您完全正确。torchvision 0.5 版本中的 RandomRotation()fill 参数上存在一个错误,可能是因为与 Pillow 版本不兼容。这个问题现已修复(PR#1760),将在下一个版本中解决。

暂时,您可以在 RandomRotation 变换中添加 fill=(0,) 来修复它。

transforms.RandomRotation(degrees=(90, -90), fill=(0,))

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注