如何为PyTorch层赋名?

之前的问题之后,我想绘制权重、偏置、激活和梯度,以实现与这个类似的结果。

使用

for name, param in model.named_parameters():    summary_writer.add_histogram(f'{name}.grad', param.grad, step_index)

正如在之前的问题中所建议的,这种方法的效果并不理想,因为层名称类似于'_decoder._decoder.4.weight',难以理解,尤其是由于研究原因,架构在不断变化。在这个运行中4在下次运行中不会相同,并且实际上毫无意义。

因此,我想为每个层赋予自己的字符串名称。


我找到了这个 PyTorch论坛讨论,但没有达成一致的最佳实践。

为PyTorch层赋名的推荐方法是什么?

即,以各种方式定义的层:

  1. 顺序型:
self._seq = nn.Sequential(nn.Linear(1, 2), nn.Linear(3, 4),)
  1. 动态型:
self._dynamic = nn.ModuleList()    for _ in range(self._n_features):         self._last_layer.append(nn.Conv1d(in_channels=5, out_channels=6, kernel_size=3, stride=1, padding=1,),)
  1. 直接型:
self._direct = nn.Linear(7, 8)
  1. 其他我没有想到的方式

我希望能够为以上述每种方式定义的每个层赋予一个字符串名称。


回答:

顺序型

传递collections.OrderedDict的实例。下面的代码会生成conv1.weightsconv1.biasconv2.weightconv2.bias(注意没有torch.nn.ReLU(),请参见本答案的结尾)。

import collectionsimport torchmodel = torch.nn.Sequential(    collections.OrderedDict(        [            ("conv1", torch.nn.Conv2d(1, 20, 5)),            ("relu1", torch.nn.ReLU()),            ("conv2", torch.nn.Conv2d(20, 64, 5)),            ("relu2", torch.nn.ReLU()),        ]    ))for name, param in model.named_parameters():    print(name)

动态型

使用ModuleDict代替ModuleList

class MyModule(torch.nn.Module):    def __init__(self):        super().__init__()        self.whatever = torch.nn.ModuleDict(            {f"my_name{i}": torch.nn.Conv2d(10, 10, 3) for i in range(5)}        )

这将为每个动态创建的模块提供whatever.my_name{i}.weight(或bias)。

直接型

只需按你想要的方式命名即可,它将按此命名

self.my_name_or_whatever = nn.Linear(7, 8)

你没有想到的

  • 如果你想绘制权重、偏置及其梯度,可以按照这条路线进行
  • 你无法通过这种方式绘制激活(或激活的输出)。请改用PyTorch钩子(如果你想在网络通过时获得每层的梯度,也可以使用这个)

对于最后一个任务,你可以使用第三方库torchfunc(免责声明:我是作者)或者直接编写你自己的钩子。

Related Posts

如何对SVC进行超参数调优?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

如何在初始训练后向模型添加训练数据?

我想在我的scikit-learn模型已经训练完成后再…

使用Google Cloud Function并行运行带有不同用户参数的相同训练作业

我正在寻找一种方法来并行运行带有不同用户参数的相同训练…

加载Keras模型,TypeError: ‘module’ object is not callable

我已经在StackOverflow上搜索并阅读了文档,…

在计算KNN填补方法中特定列中NaN值的”距离平均值”时

当我从头开始实现KNN填补方法来处理缺失数据时,我遇到…

使用巨大的S3 CSV文件或直接从预处理的关系型或NoSQL数据库获取数据的机器学习训练/测试工作

已关闭。此问题需要更多细节或更清晰的说明。目前不接受回…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注