使用Pytorch时,通过以下代码传递架构类型:
parser.add_argument('-arch', action='store', dest='arch', default= str('vgg16'))
使用以下代码调用架构名称时:
model = models.__dict__['{!r}'.format(results.arch)](pretrained=True)
我得到了以下错误:
model = models.dict‘{!r}’.format(results.arch) KeyError: “‘vgg16′”
我做错了什么?
回答:
你得到了KeyError
,这意味着你导入的models
中不包括'vgg16'
作为已知模型之一。
通过打印以下内容来检查你拥有什么模型:
print(models.__dict__.keys())
这样你就可以知道你导入了哪些模型,哪些模型缺失,然后你可以检查你的导入情况,看看'vgg16'
在哪里丢失了。