我的PyTorch forward函数可以执行额外的操作吗?

通常情况下,forward函数会将多个层连接起来,并返回最后一个层的输出。在返回之前,我可以在最后一个层之后进行一些额外的处理吗?例如,进行标量乘法和通过.view进行重塑?

我知道自动梯度计算会自动处理梯度。所以我不确定我的额外处理是否会干扰这个过程。谢谢。


回答:

通过计算图跟踪张量的梯度,而不是通过函数。只要你的张量具有requires_grad=True属性,并且它们的grad不为None,你就可以(几乎)随意操作并仍然能够进行反向传播。
只要你使用的是PyTorch的操作(例如,列在这里的这里这里),你应该没问题。

更多信息请参见这里

例如(摘自torchvision的VGG实现):

class VGG(nn.Module):    def __init__(self, features, num_classes=1000, init_weights=True):        super(VGG, self).__init__()        #  ...    def forward(self, x):        x = self.features(x)        x = self.avgpool(x)        x = torch.flatten(x, 1)  # <-- 你所询问的部分        x = self.classifier(x)        return x

更复杂的例子可以在torchvision的ResNet实现中看到:

class Bottleneck(nn.Module):    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,                 base_width=64, dilation=1, norm_layer=None):        super(Bottleneck, self).__init__()        # ...    def forward(self, x):        identity = x        out = self.conv1(x)        out = self.bn1(out)        out = self.relu(out)        out = self.conv2(out)        out = self.bn2(out)        out = self.relu(out)        out = self.conv3(out)        out = self.bn3(out)        if self.downsample is not None:    # <-- 条件执行!            identity = self.downsample(x)        out += identity  # <-- 就地操作        out = self.relu(out)        return out

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注