我想了解为什么在定义像GAN这样的神经网络类时,我们要将torch.nn.Module作为参数传递。
import torchimport torch.nn as nnclass Generator(nn.Module): def __init__(self, input_size, hidden_size, output_size, f): super(Generator, self).__init__() self.map1 = nn.Linear(input_size, hidden_size) self.map2 = nn.Linear(hidden_size, hidden_size) self.map3 = nn.Linear(hidden_size, output_size) self.f = f
回答:
这一行
class Generator(nn.Module):
简单来说,意味着Generator
类将继承nn.Module
类,这不是一个参数。
然而,特殊的__init__
方法:
def __init__(self, input_size, hidden_size, output_size, f):
包含了self
,这可能使你认为它是一个参数。
但这实际上是Python类实例的self
。关于self
是否应该保留曾有过争论,但Guido在他的博客中解释了为什么它必须保留
,见这里
。