当我从 torchvision.models
获取一个预训练模型后,我想对所有的 ReLU
实例应用 register_backward_hook(f)
,就像这样:
for pos, module in self.model.features._modules.items(): for sub_module in module: if isinstance(module, ReLU): module.register_backward_hook(f)
对我来说的问题是如何在一个模型中找到所有的 ReLU
。对于 densenet161
来说,ReLU
不仅存在于 model.features._modules
中,还存在于自定义的密集层中,例如 model.features._modules['denseblock1'][0]
。对于 resnet151
来说,ReLU
存在于 model._modules
及其自定义层中,例如 model._modules['layer1']
。
有什么方法可以找到模型内部的所有 ReLU
吗?
回答:
遍历模型所有组件的更优雅方法是使用 modules()
方法:
from torch import nnfor module in self.model.modules(): if isinstance(module, nn.ReLU): module.register_backward_hook(f)
如果你不想获取所有子模块,只想要直接的子模块,你可以考虑使用 children()
方法来代替 modules()
。你还可以使用 named_modules()
方法来获取子模块的名称。