PyTorch – 将ProGAN代理从pth转换为onnx

我使用这个PyTorch重实现训练了一个ProGAN代理,并将代理保存为.pth文件。现在我需要将代理转换为.onnx格式,我使用以下脚本进行转换:

from torch.autograd import Variableimport torch.onnximport torchvisionimport torchdevice = torch.device("cuda")dummy_input = torch.randn(1, 3, 64, 64)state_dict = torch.load("GAN_agent.pth", map_location = device)torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")

运行后,我得到了错误AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'(完整提示如下)。据我所知,问题在于将代理转换为.onnx格式需要更多信息。我遗漏了什么吗?

---------------------------------------------------------------------------AttributeError                            Traceback (most recent call last)<ipython-input-2-c64481d4eddd> in <module>     10 state_dict = torch.load("GAN_agent.pth", map_location = device)     11 ---> 12 torch.onnx.export(state_dict, dummy_input, "GAN_agent.onnx")~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)    146                         operator_export_type, opset_version, _retain_param_name,    147                         do_constant_folding, example_outputs,--> 148                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs)    149     150 ~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)     64             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,     65             example_outputs=example_outputs, strip_doc_string=strip_doc_string,---> 66             dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)     67      68 ~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size)    414                                                         example_outputs, propagate,    415                                                         _retain_param_name, do_constant_folding,--> 416                                                         fixed_batch_size=fixed_batch_size)    417     418         # TODO: Don't allocate a in-memory string for the protobuf~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size)    277             model.graph, tuple(in_vars), False, propagate)    278     else:--> 279         graph, torch_out = _trace_and_get_graph_from_model(model, args, training)    280         state_dict = _unique_state_dict(model)    281         params = list(state_dict.values())~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\onnx\utils.py in _trace_and_get_graph_from_model(model, args, training)    226     # A basic sanity check: make sure the state_dict keys are the same    227     # before and after running the model.  Fail fast!--> 228     orig_state_dict_keys = _unique_state_dict(model).keys()    229     230     # By default, training=False, which is good because running a model in~\anaconda3\envs\Basemap_upres\lib\site-packages\torch\jit\__init__.py in _unique_state_dict(module, keep_vars)    283     # id(v) doesn't work with it. So we always get the Parameter or Buffer    284     # as values, and deduplicate the params using Parameters and Buffers--> 285     state_dict = module.state_dict(keep_vars=True)    286     filtered_dict = type(state_dict)()    287     seen_ids = set()AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'

回答:

你拥有的文件是state_dict,它们只是将层名称映射到tensor权重、偏置等(有关更详细的介绍,请参见这里)。

这意味着你需要一个模型,以便这些保存的权重和偏置可以映射到模型上,但首先我们需要做一些准备工作:

1. 模型准备

克隆存储模型定义的仓库,并打开文件/pro_gan_pytorch/pro_gan_pytorch/PRO_GAN.py。为了使其与onnx兼容,我们需要进行一些修改。onnx导出器要求input只能作为torch.tensor传递(或者是那些的list/dict),而Generator类需要intfloat参数。

简单的解决方案是稍微修改forward函数(文件中的第80行,你可以在GitHub上验证)如下所示:

def forward(self, x, depth, alpha):    """    forward pass of the Generator    :param x: input noise    :param depth: current depth from where output is required    :param alpha: value of alpha for fade-in effect    :return: y => output    """    # THOSE TWO LINES WERE ADDED    # We will pas tensors but unpack them here to `int` and `float`    depth = depth.item()    alpha = alpha.item()    # THOSE TWO LINES WERE ADDED    assert depth < self.depth, "Requested output depth cannot be produced"    y = self.initial_block(x)    if depth > 0:        for block in self.layers[: depth - 1]:            y = block(y)        residual = self.rgb_converters[depth - 1](self.temporaryUpsampler(y))        straight = self.rgb_converters[depth](self.layers[depth - 1](y))        out = (alpha * straight) + ((1 - alpha) * residual)    else:        out = self.rgb_converters[0](y)    return out

这里仅添加了通过item()解包。任何不是Tensor类型的输入都应该在函数定义中打包为一个Tensor,并在函数顶部尽快解包。这不会破坏你创建的检查点,所以不用担心,因为它只是layer-weight映射。

2. 模型导出

将此脚本放置在/pro_gan_pytorch(其中README.md也位于此处):

请注意以下几点:

  • 我们必须在加载权重之前创建模型,因为它只是state_dict
  • 需要torch.nn.DataParallel,因为模型是在此上训练的(不确定你的情况,请相应调整)。加载后,我们可以通过module属性获取模块本身。
  • 一切都转换为CPU,我认为这里不需要GPU。不过如果你坚持,也可以转换为GPU
  • 生成器的虚拟输入不能是图像(我使用了存储库作者在他们的Google Drive上提供的文件),它必须是具有512个元素的噪声。

运行它,你的.onnx文件就应该出现了。

哦,由于你想要不同的检查点,你可能需要遵循类似的程序,不过不能保证一切都会正常工作(不过看起来应该是可以的)。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注