Unable to load model from a checkpoint in PyTorch Lightning

I am training a U-Net model with PyTorch Lightning. The model trains successfully, but when I try to load it from a checkpoint after training, I run into the error below.

The full traceback:

Traceback (most recent call last):
  File "src/train.py", line 269, in <module>
    main(sys.argv[1:])
  File "src/train.py", line 263, in main
    model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs)
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 190, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
  File "src/train.py", line 162, in __init__
    self.inc = double_conv(self.n_channels, 64)
  File "src/train.py", line 122, in double_conv
    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 406, in __init__
    super(Conv2d, self).__init__(
  File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 50, in __init__
    if in_channels % groups != 0:
TypeError: unsupported operand type(s) for %: 'dict' and 'int'

I have looked through GitHub issues and the forums but could not figure out what the problem is. Please help me resolve it.

Here are my model code and the steps I use to load the checkpoint:

Model:

class Unet(pl.LightningModule):
    def __init__(self, n_channels, n_classes=5):
        super(Unet, self).__init__()
        # self.hparams = hparams
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = True
        self.logger = WandbLogger(name="Adam", project="pytorchlightning")

        def double_conv(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        def down(in_channels, out_channels):
            return nn.Sequential(
                nn.MaxPool2d(2), double_conv(in_channels, out_channels)
            )

        class up(nn.Module):
            def __init__(self, in_channels, out_channels, bilinear=False):
                super().__init__()
                if bilinear:
                    self.up = nn.Upsample(
                        scale_factor=2, mode="bilinear", align_corners=True
                    )
                else:
                    self.up = nn.ConvTranspose2d(
                        in_channels // 2, in_channels // 2, kernel_size=2, stride=2
                    )
                self.conv = double_conv(in_channels, out_channels)

            def forward(self, x1, x2):
                x1 = self.up(x1)
                # [?, C, H, W]
                diffY = x2.size()[2] - x1.size()[2]
                diffX = x2.size()[3] - x1.size()[3]
                x1 = F.pad(
                    x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
                )
                x = torch.cat([x2, x1], dim=1)
                return self.conv(x)

        self.inc = double_conv(self.n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.out = nn.Conv2d(64, self.n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.out(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)
        # wandb_logger.log_metrics({"loss":loss})
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.logger.log_metrics({"train_loss": avg_train_loss})
        return {"average_loss": avg_train_loss}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)
        return {"test_loss": loss, "pred": y_hat}

    def test_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        return {"avg_test_loss": avg_loss}

    def MSE(self, logits, labels):
        return torch.mean((logits - labels) ** 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1, weight_decay=1e-8)

Main function:

def main(expconfig):
    # Define checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        filepath="/home/africa_wikilimo/data/model_checkpoint/",
        save_top_k=1,
        verbose=True,
        monitor="loss",
        mode="min",
        prefix="",
    )

    # Initialize dataset
    print("Initializing climate dataset....")
    clima_train = Clima_Dataset(expconfig[0])

    # Initialize dataloaders
    print("Initializing train dataloader....")
    train_dataloader = DataLoader(clima_train, batch_size=2, num_workers=4)

    # Initialize model and trainer
    print("Initializing model...")
    model = Unet(n_channels=9, n_classes=5)
    print("Initializing trainer....")
    if torch.cuda.is_available():
        model.cuda()
        trainer = pl.Trainer(
            max_epochs=1,
            gpus=1,
            checkpoint_callback=checkpoint_callback,
            early_stop_callback=None,
        )
    else:
        trainer = pl.Trainer(max_epochs=1, checkpoint_callback=checkpoint_callback)

    trainer.fit(model, train_dataloader=train_dataloader)
    print(checkpoint_callback.best_model_path)
    model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)

Answer:

Cause

This happens because the model cannot restore its hyperparameters (n_channels, n_classes=5) from the checkpoint: you never explicitly saved them. As the traceback shows, _load_model_state re-instantiates the model with cls(*cls_args, **cls_kwargs), so the stored hyperparameter dict is bound to the first positional parameter, n_channels, and is then forwarded to nn.Conv2d as in_channels, where the in_channels % groups check fails with dict % int.
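To make the failure mode concrete, here is a minimal reproduction of the resulting error; the dict contents are hypothetical, standing in for whatever the checkpoint stored:

import torch.nn as nn

# A dict bound to n_channels reaches Conv2d as in_channels.
hparams = {"n_channels": 9, "n_classes": 5}
nn.Conv2d(hparams, 64, kernel_size=3, padding=1)
# Conv2d validates in_channels % groups, i.e. dict % int:
# TypeError: unsupported operand type(s) for %: 'dict' and 'int'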

Solution

You can fix this by calling self.save_hyperparameters('n_channels', 'n_classes') in the __init__ method of your Unet class. See the PyTorch Lightning hyperparameters documentation for more details on this method. save_hyperparameters writes the selected arguments to an hparams.yaml file and stores them together with the checkpoint, so load_from_checkpoint can rebuild the model, as sketched below.
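A minimal sketch of the fix, showing only the changed lines of __init__ (everything else in the class stays as in the question):

class Unet(pl.LightningModule):
    def __init__(self, n_channels, n_classes=5):
        super(Unet, self).__init__()
        # Record the constructor arguments so Lightning stores them in the
        # checkpoint and can pass them back when re-instantiating the model.
        self.save_hyperparameters("n_channels", "n_classes")
        self.n_channels = n_channels
        self.n_classes = n_classes
        # ... rest of __init__ unchanged ...

Note that checkpoints written before this change still lack the hyperparameters; after retraining, loading works without passing any constructor arguments:

model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)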

Thanks to @Adrian Wälchli (awaelchli) from the PyTorch Lightning core contributors team, who suggested this solution when I ran into the same problem.
