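I trained a ProGAN agent using this PyTorch reimplementation and saved the agent as a .pth file. Now I need to convert the agent to .onnx format, which I am doing with the following script: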
    from torch.autograd import Variable
    import torch.onnx
    import torchvision
    import torch

    device = torch.device("cuda")
    dummy_input = torch.randn(1, 3, 64, 64)
    state_dict = torch.load("GAN_agent.pth", map_location = device)
    torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")
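When I run it, I get the error AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict' (full traceback below). As far as I understand, the problem is that converting the agent to .onnx format requires more information. Am I missing something?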
    ---------------------------------------------------------------------------
    AttributeError                            Traceback (most recent call last)
    <ipython-input-2-c64481d4eddd> in <module>
         10 state_dict = torch.load("GAN_agent.pth", map_location = device)
         11
    ---> 12 torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")

    ~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
        146     operator_export_type, opset_version, _retain_param_name,
        147     do_constant_folding, example_outputs,
    --> 148     strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
        149
        150

    ~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
         64     _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
         65     example_outputs=example_outputs, strip_doc_string=strip_doc_string,
    ---> 66     dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
         67
         68

    ~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size)
        414     example_outputs, propagate,
        415     _retain_param_name, do_constant_folding,
    --> 416     fixed_batch_size=fixed_batch_size)
        417
        418 # TODO: Don't allocate a in-memory string for the protobuf

    ~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size)
        277     model.graph, tuple(in_vars), False, propagate)
        278 else:
    --> 279     graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        280 state_dict = _unique_state_dict(model)
        281 params = list(state_dict.values())

    ~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _trace_and_get_graph_from_model(model, args, training)
        226 # A basic sanity check: make sure the state_dict keys are the same
        227 # before and after running the model. Fail fast!
    --> 228 orig_state_dict_keys = _unique_state_dict(model).keys()
        229
        230 # By default, training=False, which is good because running a model in

    ~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\jit\__init__.py in _unique_state_dict(module, keep_vars)
        283 # id(v) doesn't work with it. So we always get the Parameter or Buffer
        284 # as values, and deduplicate the params using Parameters and Buffers
    --> 285 state_dict = module.state_dict(keep_vars=True)
        286 filtered_dict = type(state_dict)()
        287 seen_ids = set()

    AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'
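Answer:

The files you have are state_dicts, which are simply mappings from layer names to tensor weights, biases and the like (see here for a more thorough introduction).

This means you need a model so that the saved weights and biases can be mapped onto it, but first some preparation is required: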
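To see why the export fails, it may help to look at what a state_dict actually is. The sketch below uses a small stand-in torch.nn.Linear layer (an assumption for illustration, not the ProGAN generator) and shows that its state_dict is a plain OrderedDict of tensors, with no forward pass and no state_dict method of its own, which is what the exporter trips over:

```python
from collections import OrderedDict

import torch

# Stand-in layer used only for illustration; the real checkpoint holds
# the ProGAN generator's parameters instead.
layer = torch.nn.Linear(4, 2)
state = layer.state_dict()

# The state_dict maps parameter names to tensors and nothing more.
print(type(state).__name__)
print(list(state.keys()))
print(isinstance(state, OrderedDict))

# It is not a module: it cannot be called and has no state_dict() method,
# so torch.onnx.export has nothing to trace.
print(hasattr(state, "state_dict"))
print(callable(state))
```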
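1. Model preparation

Clone the repository where the model definition lives and open the file /pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py. A few modifications are needed to make it work with onnx. The onnx exporter requires inputs to be passed only as torch.tensor (or a list/dict of those), while the Generator class needs int and float arguments.

A simple solution is to slightly modify the forward function (line 80 in the file, you can verify it on GitHub) as follows: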
    def forward(self, x, depth, alpha):
        """
        forward pass of the Generator
        :param x: input noise
        :param depth: current depth from where output is required
        :param alpha: value of alpha for fade-in effect
        :return: y => output
        """

        # THOSE TWO LINES WERE ADDED
        # We will pass tensors but unpack them here to `int` and `float`
        depth = depth.item()
        alpha = alpha.item()
        # THOSE TWO LINES WERE ADDED
        assert depth < self.depth, "Requested output depth cannot be produced"

        y = self.initial_block(x)

        if depth > 0:
            for block in self.layers[: depth - 1]:
                y = block(y)

            residual = self.rgb_converters[depth - 1](self.temporaryUpsampler(y))
            straight = self.rgb_converters[depth](self.layers[depth - 1](y))

            out = (alpha * straight) + ((1 - alpha) * residual)

        else:
            out = self.rgb_converters[0](y)

        return out
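Only the unpacking via item() was added here. Any input that is not of Tensor type should be packed into a Tensor at the call site and unpacked as early as possible at the top of the function. This will not break the checkpoint you created, so no worries there, since a checkpoint is just a layer-to-weight mapping.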
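The pack-and-unpack pattern can be tried in isolation. Below is a minimal sketch with a hypothetical TinyBlend module (not part of pro_gan_pytorch) whose forward accepts depth and alpha as tensors and unpacks them with item(). One caveat worth keeping in mind: when the exporter traces a model, values obtained through item() are generally treated as Python constants, so the depth and alpha used for the dummy input are likely to be baked into the exported graph.

```python
import torch


class TinyBlend(torch.nn.Module):
    """Hypothetical stand-in that mimics the Generator's forward signature."""

    def __init__(self):
        super().__init__()
        self.straight = torch.nn.Linear(8, 8)
        self.residual = torch.nn.Linear(8, 8)

    def forward(self, x, depth, alpha):
        # Unpack the tensor arguments into plain Python numbers right away.
        depth = depth.item()
        alpha = alpha.item()
        if depth > 0:
            # Fade-in blend, same formula as in the Generator's forward.
            return alpha * self.straight(x) + (1 - alpha) * self.residual(x)
        return self.residual(x)


model = TinyBlend().eval()
x = torch.randn(1, 8)

# Non-tensor arguments are packed into tensors at the call site.
with torch.no_grad():
    blended = model(x, torch.tensor([1]), torch.tensor([0.25]))
    plain = model(x, torch.tensor([0]), torch.tensor([0.25]))

print(blended.shape, plain.shape)
```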
2. Model export

Place this script in /pro_gan_pytorch (where README.md is also located):
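A minimal sketch of what such a script could look like, consistent with the notes that follow. It rests on several assumptions: the helper names load_generator and export_generator are hypothetical, the generator is passed in as an argument so the block does not depend on the repository layout, and the default depth and alpha are placeholder values that would need to suit your checkpoint.

```python
import torch


def load_generator(generator, checkpoint_path):
    """Load a DataParallel-trained state_dict and return the bare module on CPU."""
    # Wrap first so the "module."-prefixed keys in the checkpoint line up.
    wrapped = torch.nn.DataParallel(generator)
    wrapped.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    # The underlying model is available through the `module` attribute.
    return wrapped.module.to("cpu").eval()


def export_generator(generator, onnx_path, latent_size=512, depth=5, alpha=0.1):
    """Export with noise plus tensor-packed depth and alpha as dummy inputs."""
    dummy_inputs = (
        torch.randn(1, latent_size),  # noise vector, not an image
        torch.tensor([depth]),        # unpacked with item() inside forward
        torch.tensor([alpha]),
    )
    torch.onnx.export(generator, dummy_inputs, onnx_path)


# Hypothetical usage from the repository root; the import path and the
# depth argument are assumptions that depend on your checkpoint:
#
#   from pro_gan_pytorch import PRO_GAN as pg
#   gen = load_generator(pg.Generator(depth=9), "GAN_agent.pth")
#   export_generator(gen, "GAN_agent.onnx")
```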
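Please note a few things:
- We have to create the model before loading the weights, since the file is only a state_dict.
- torch.nn.DataParallel is needed because that is what the model was trained on (not sure about your case, please adjust accordingly). After loading, we can get the module itself via the module attribute.
- Everything is moved to CPU; I don't think a GPU is needed here. You can move everything to the GPU if you insist, though.
- The dummy input to the generator cannot be an image (I used the files provided by the repository authors on their Google Drive); it has to be noise with 512 elements.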
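As an alternative to wrapping the model in torch.nn.DataParallel just to load the weights, the "module." prefix that DataParallel adds to every key can be stripped from the state_dict instead. A minimal sketch, assuming the checkpoint keys carry that prefix; strip_module_prefix is a hypothetical helper name and the toy values stand in for real tensors:

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from keys."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Only strip a leading prefix; other keys pass through unchanged.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Toy example with placeholder values instead of tensors.
saved = OrderedDict(
    [("module.initial_block.weight", 1), ("module.initial_block.bias", 2)]
)
cleaned = strip_module_prefix(saved)
print(list(cleaned.keys()))
# → ['initial_block.weight', 'initial_block.bias']
```

The cleaned dictionary could then be passed to load_state_dict on the bare Generator.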
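Run it and your .onnx file should appear.

Oh, and since you are after a different checkpoint, you may want to follow a similar procedure, though there is no guarantee everything will work (it does look like it should, though).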