如何为PyTorch层赋名?

之前的问题之后,我想绘制权重、偏置、激活和梯度,以实现与这个类似的结果。

使用

for name, param in model.named_parameters():    summary_writer.add_histogram(f'{name}.grad', param.grad, step_index)

正如在之前的问题中所建议的,这种方法的效果并不理想,因为层名称类似于'_decoder._decoder.4.weight',难以理解,尤其是由于研究原因,架构在不断变化。在这个运行中4在下次运行中不会相同,并且实际上毫无意义。

因此,我想为每个层赋予自己的字符串名称。


我找到了这个 PyTorch论坛讨论,但没有达成一致的最佳实践。

为PyTorch层赋名的推荐方法是什么?

即,以各种方式定义的层:

  1. 顺序型:
self._seq = nn.Sequential(nn.Linear(1, 2), nn.Linear(3, 4),)
  1. 动态型:
self._dynamic = nn.ModuleList()    for _ in range(self._n_features):         self._last_layer.append(nn.Conv1d(in_channels=5, out_channels=6, kernel_size=3, stride=1, padding=1,),)
  1. 直接型:
self._direct = nn.Linear(7, 8)
  1. 其他我没有想到的方式

我希望能够为以上述每种方式定义的每个层赋予一个字符串名称。


回答:

顺序型

传递collections.OrderedDict的实例。下面的代码会生成conv1.weightsconv1.biasconv2.weightconv2.bias(注意没有torch.nn.ReLU(),请参见本答案的结尾)。

import collectionsimport torchmodel = torch.nn.Sequential(    collections.OrderedDict(        [            ("conv1", torch.nn.Conv2d(1, 20, 5)),            ("relu1", torch.nn.ReLU()),            ("conv2", torch.nn.Conv2d(20, 64, 5)),            ("relu2", torch.nn.ReLU()),        ]    ))for name, param in model.named_parameters():    print(name)

动态型

使用ModuleDict代替ModuleList

class MyModule(torch.nn.Module):    def __init__(self):        super().__init__()        self.whatever = torch.nn.ModuleDict(            {f"my_name{i}": torch.nn.Conv2d(10, 10, 3) for i in range(5)}        )

这将为每个动态创建的模块提供whatever.my_name{i}.weight(或bias)。

直接型

只需按你想要的方式命名即可,它将按此命名

self.my_name_or_whatever = nn.Linear(7, 8)

你没有想到的

  • 如果你想绘制权重、偏置及其梯度,可以按照这条路线进行
  • 你无法通过这种方式绘制激活(或激活的输出)。请改用PyTorch钩子(如果你想在网络通过时获得每层的梯度,也可以使用这个)

对于最后一个任务,你可以使用第三方库torchfunc(免责声明:我是作者)或者直接编写你自己的钩子。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注