Removing the max pooling layer causes a CUDA out-of-memory error in PyTorch

GPU: GTX 1070 Ti with 8 GB, batch size 64, input image size 128*128. My UNet with a ResNet152 encoder was previously running fine:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

# ConvRelu (3x3 convolution followed by ReLU) is defined elsewhere in the project.

class UNetResNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)  # I want to remove this pooling layer
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

        self.dec5 = DecoderBlockV(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, is_deconv)
        self.dec1 = DecoderBlockV(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        center = self.center(conv5)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        return self.final(F.dropout2d(dec0, p=self.dropout_2d))


# blocks
class DecoderBlockV(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super(DecoderBlockV, self).__init__()
        self.in_channels = in_channels

        if is_deconv:
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                   padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        else:
            self.block = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            )

    def forward(self, x):
        return self.block(x)


class DecoderCenter(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super(DecoderCenter, self).__init__()
        self.in_channels = in_channels

        if is_deconv:
            """
                Parameters for Deconvolution were chosen to avoid artifacts, following
                link https://distill.pub/2016/deconv-checkerboard/
            """
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                   padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        else:
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels)
            )

    def forward(self, x):
        return self.block(x)
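For reference, the model above can be exercised with a quick sketch like the following (ConvRelu is assumed to come from the project's helpers, and num_classes=1 is just an illustrative choice for binary segmentation):

# Quick shape check for the original (pooled) model defined above.
# Assumptions: ConvRelu is importable from the project; num_classes=1 is illustrative.
model = UNetResNet(encoder_depth=152, num_classes=1, is_deconv=True).eval()
x = torch.randn(2, 3, 128, 128)   # batch of 2 just to check shapes on CPU
with torch.no_grad():
    out = model(x)
print(out.shape)                  # expected: torch.Size([2, 1, 128, 128])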

Then I edited the class so that it works without the pooling layer:

class UNetResNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')

        self.relu = nn.ReLU(inplace=True)
        self.input_adjust = nn.Sequential(self.encoder.conv1,
                                          self.encoder.bn1,
                                          self.encoder.relu)
        self.conv1 = self.encoder.layer1
        self.conv2 = self.encoder.layer2
        self.conv3 = self.encoder.layer3
        self.conv4 = self.encoder.layer4

        self.dec4 = DecoderBlockV(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec2 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec1 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2, is_deconv)
        self.final = nn.Conv2d(num_filters * 2 * 2, num_classes, kernel_size=1)

    def forward(self, x):
        input_adjust = self.input_adjust(x)
        conv1 = self.conv1(input_adjust)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        center = self.conv4(conv3)
        dec4 = self.dec4(center)  # no center block now
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = F.dropout2d(self.dec1(torch.cat([dec2, conv1], 1)), p=self.dropout_2d)
        return self.final(dec1)
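As a sanity check on the edit, the encoder feature maps can be inspected directly (a rough sketch, assuming the modified class above is the one in scope). With the pool gone, every encoder stage is twice as large in each spatial dimension:

# Rough sketch: print the encoder feature-map sizes of the edited model (no max pool).
model = UNetResNet(encoder_depth=152, num_classes=1, is_deconv=True).eval()
x = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    feats = model.input_adjust(x)   # [1, 64, 64, 64]    (was 32x32 with the pool)
    c1 = model.conv1(feats)         # [1, 256, 64, 64]   (was 32x32)
    c2 = model.conv2(c1)            # [1, 512, 32, 32]   (was 16x16)
    c3 = model.conv3(c2)            # [1, 1024, 16, 16]  (was 8x8)
    c4 = model.conv4(c3)            # [1, 2048, 8, 8]    (was 4x4)
for t in (feats, c1, c2, c3, c4):
    print(tuple(t.shape))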

is_deconv – True in both cases. After the change it no longer works with batch size 64; it only works with batch size 16, or with batch size 64 but only with ResNet16 – otherwise CUDA runs out of memory. What am I doing wrong?

Full error traceback:

~/Desktop/ml/salt/open-solution-salt-identification-master/common_blocks/unet_models.py in forward(self, x)
    418         conv1 = self.conv1(input_adjust)
    419         conv2 = self.conv2(conv1)
--> 420         conv3 = self.conv3(conv2)
    421         center = self.conv4(conv3)
    422         dec4 = self.dec4(center)

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    355             result = self._slow_forward(*input, **kwargs)
    356         else:
--> 357             result = self.forward(*input, **kwargs)
    358         for hook in self._forward_hooks.values():
    359             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
     65     def forward(self, input):
     66         for module in self._modules.values():
---> 67             input = module(input)
     68         return input
     69

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    355             result = self._slow_forward(*input, **kwargs)
    356         else:
--> 357             result = self.forward(*input, **kwargs)
    358         for hook in self._forward_hooks.values():
    359             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.6/site-packages/torchvision-0.2.0-py3.6.egg/torchvision/models/resnet.py in forward(self, x)
     79
     80         out = self.conv2(out)
---> 81         out = self.bn2(out)
     82         out = self.relu(out)

Answer:

There is nothing wrong with the code itself – the out-of-memory error is the expected cost of removing the pool. nn.MaxPool2d(2, 2) halved the height and width of the feature maps before layer1. Without it, every encoder stage (layer1–layer4) and every decoder block now runs at twice the spatial resolution in each dimension, so the activations that must be stored for the backward pass take roughly 4x as much memory. On an 8 GB GTX 1070 Ti that factor of 4 is exactly the difference between batch size 64 (out of memory) and batch size 16 (fits), which matches what you are seeing. To keep the higher-resolution encoder you have to pay for it somewhere else: a smaller batch size, a smaller encoder, gradient checkpointing (torch.utils.checkpoint), or mixed-precision training.
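A rough way to confirm the factor on your own machine (only a sketch: UNetResNetPooled and UNetResNetNoPool are hypothetical names for the two versions of the class posted above, and a small batch is used so that both variants fit in memory):

import torch

def peak_memory_mb(model, batch_size=8):
    # One forward/backward pass, then report the peak CUDA memory in MB.
    model = model.cuda().train()
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(batch_size, 3, 128, 128, device='cuda')
    model(x).mean().backward()
    return torch.cuda.max_memory_allocated() / 1024 ** 2

# Hypothetical names for the original (with pool) and edited (without pool) classes.
print(peak_memory_mb(UNetResNetPooled(encoder_depth=152, num_classes=1, is_deconv=True)))
print(peak_memory_mb(UNetResNetNoPool(encoder_depth=152, num_classes=1, is_deconv=True)))

The second number should come out at roughly four times the first, which is why batch size 64 no longer fits on an 8 GB card.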
