我正在尝试在pytorch
中实现一个全连接网络(FCN),其整体结构如下:
目前的代码如下所示:
class SNet(nn.Module): def __init__(self): super(SNet, self).__init__() self.enc_a = encoder(...) self.dec_a = decoder(...) self.enc_b = encoder(...) self.dec_b = decoder(...) def forward(self, x1, x2): x1 = self.enc_a(x1) x2 = self.enc_b(x2) x2 = self.dec_b(x2) x1 = self.dec_a(torch.cat((x1, x2), dim=-1) return x1, x2
在keras
中,使用函数式API相对容易实现这一点。然而,我找不到任何具体的例子或教程来在pytorch
中实现这一点。
- 如何在训练后丢弃
dec_a
(自编码器分支的解码器部分)? - 在联合训练期间,
loss
将是来自两个分支的loss
的总和(可选加权)吗?
回答:
你也可以为你的模型定义单独的训练和推理模式:
class SNet(nn.Module): def __init__(self): super(SNet, self).__init__() self.enc_a = encoder(...) self.dec_a = decoder(...) self.enc_b = encoder(...) self.dec_b = decoder(...) self.training = True def forward(self, x1, x2): if self.training: x1 = self.enc_a(x1) x2 = self.enc_b(x2) x2 = self.dec_b(x2) x1 = self.dec_a(torch.cat((x1, x2), dim=-1) return x1, x2 else: x1 = self.enc_a(x1) x2 = self.enc_b(x2) x2 = self.dec_b(x2) return x2
这些代码块是示例,可能会与你想要的有所不同,因为我认为在你定义的训练和推理操作的块图与代码之间存在一些歧义,但无论如何,你应该能理解如何仅在训练模式下使用某些模块。然后你可以相应地设置这个变量。