我在尝试使用Pytorch中的AdaptiveAvgPool3D时遇到了这个错误。以下是错误追踪
Traceback (most recent call last):
文件 “/scratch/a.bip5/BraTS 2021/./sisa.py”, 第395行,outputs = model(inputs)文件 “/home/a.bip5/.conda/envs/pix2pix/lib/python3.9/site-packages/torch/nn/modules/module.py”, 第1051行,_call_impl返回 forward_call(*input, **kwargs)文件 “/home/a.bip5/.conda/envs/pix2pix/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py”, 第166行,forward返回 self.module(*inputs[0], **kwargs[0])文件 “/home/a.bip5/.conda/envs/pix2pix/lib/python3.9/site-packages/torch/nn/modules/module.py”, 第1051行,_call_impl返回 forward_call(*input, **kwargs)文件 “/scratch/a.bip5/BraTS 2021/./sisa.py”, 第96行,forwardx1 = self.pool1(x)文件 “/home/a.bip5/.conda/envs/pix2pix/lib/python3.9/site-packages/torch/nn/modules/module.py”, 第1051行,_call_impl返回 forward_call(*input, **kwargs)文件 “/scratch/a.bip5/BraTS 2021/./sisa.py”, 第135行,forwardx1=aa(x)文件 “/home/a.bip5/.conda/envs/pix2pix/lib/python3.9/site-packages/torch/nn/modules/module.py”, 第1051行,_call_impl返回 forward_call(*input, **kwargs)文件 “/home/a.bip5/.conda/envs/pix2pix/lib/python3.9/site-packages/torch/nn/modules/pooling.py”, 第1166行,forward返回 F.adaptive_avg_pool3d(input, self.output_size)文件 “/home/a.bip5/.conda/envs/pix2pix/lib/python3.9/site-packages/torch/nn/functional.py”, 第1148行,adaptive_avg_pool3d返回 torch._C._nn.adaptive_avg_pool3d(input, _output_size)TypeError: adaptive_avg_pool3d(): 参数’output_size’(位置2)必须是整数元组,而不能是列表
在查看错误堆栈时,我在../functional.py中发现了这段代码:
if has_torch_function_unary(input): return handle_torch_function( adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices )_output_size = _list_with_default(output_size, input.size())return torch._C._nn.adaptive_max_pool3d(input, _output_size)
打印output_size和_output_size的类型显示其中一个是元组(如预期的那样),而另一个在传递给库函数之前被转换成了列表。我不明白的是,为什么如果使用这个列表的函数不喜欢它,为什么还要进行这种转换?如果这个转换的目的是为了引发错误,那么满足has_torch_function_unary
条件的条件是什么?
任何帮助都将不胜感激。
编辑:
我尝试通过在最终返回语句中使用output_size
而不是_output_size
来绕过这个问题。这导致了一个更加神秘的错误-
TypeError: adaptive_avg_pool3d(): 参数’output_size’(位置2)必须是整数元组,而不能是元组
Pytorch如何区分整数元组和元组?
回答:
错误的原因原来是类型错误。在将大小传递给AdaptiveAvgPool3d时,我通过除法获取了大小的值。尽管余数为0,但在python中除法默认将值保存为浮点数。这意味着作为output_size
传递的元组实际上是浮点数的元组,_list_with_default将其转换为列表,但如果元组由整数组成,则可能不会转换。简单地对传递给AdaptiveAvgPool3d的每个维度使用int(..)就是解决方案。