使用 nn.ModuleDict 编码的训练网络

我被分配编写一个使用 nn.ModuleDict 的简单网络。以下是我的实现:

third_model = torch.nn.ModuleDict({'flatten': torch.nn.Flatten(),'fc1': torch.nn.Linear(32 * 32 * 3, 1024),'relu': torch.nn.ReLU(),'fc2': torch.nn.Linear(1024, 240),'relu': torch.nn.ReLU(),'fc3': torch.nn.Linear(240, 10)})

然后我尝试使用 cuda 训练它:

third_model.to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(third_model.parameters(), lr=0.001, momentum=0.9)train(third_model, criterion, optimizer, train_dataloader, test_dataloader)

函数 “train(model, criterion, optimizer, train_dataloader, test_dataloader)” 用于训练模型并可视化模型的损失和准确率。它的工作正常。

训练函数如下:

def train(model, criterion, optimizer, train_dataloader, test_dataloader):    train_loss_log = []    train_acc_log = []    val_loss_log = []    val_acc_log = []        for epoch in range(NUM_EPOCH):        model.train()        train_loss = 0.        train_size = 0        train_acc = 0.        for imgs, labels in train_dataloader:            imgs, labels = imgs.to(device), labels.to(device)                        optimizer.zero_grad()            y_pred = model(imgs)            loss = criterion(y_pred, labels)            loss.backward()            optimizer.step()                        train_loss += loss.item()            train_size += y_pred.size(0)            train_loss_log.append(loss.data / y_pred.size(0))            _, pred_classes = torch.max(y_pred, 1)            train_acc += (pred_classes == labels).sum().item()            train_acc_log.append(np.mean((pred_classes == labels).cpu().numpy()))        val_loss = 0.        val_size = 0        val_acc = 0.        model.eval()        with torch.no_grad():            for imgs, labels in test_dataloader:                imgs, labels = imgs.to(device), labels.to(device)                pred = model(imgs)                loss = criterion(pred, labels)                val_loss += loss.item()                val_size += pred.size(0)                _, pred_classes = torch.max(pred, 1)                val_acc += (pred_classes == labels).sum().item()        val_loss_log.append(val_loss / val_size)        val_acc_log.append(val_acc / val_size)        clear_output()        plot_history(train_loss_log, val_loss_log, 'loss')        plot_history(train_acc_log, val_acc_log, 'accuracy')        print('训练损失:', train_loss / train_size)        print('训练准确率:', train_acc / train_size)        print('验证损失:', val_loss / val_size)        print('验证准确率:', val_acc / val_size)

我已经训练了使用 nn.Sequential 编码的模型,一切正常。然而,使用 nn.ModuleDict 时,我遇到了一个错误:

TypeError                                 Traceback (most recent call last)<ipython-input-144-8b33ad3aad2c> in <module>()      2 optimizer = optim.SGD(third_model.parameters(), lr=0.001, momentum=0.9)      3 ----> 4 train(third_model, criterion, optimizer, train_dataloader, test_dataloader)1 frames/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)    720             result = self._slow_forward(*input, **kwargs)    721         else:--> 722             result = self.forward(*input, **kwargs)    723         for hook in itertools.chain(    724                 _global_forward_hooks.values(),TypeError: forward() takes 1 positional argument but 2 were given

我尝试查找有关 nn.ModuleDict 的任何文档,但似乎没有使用它编码网络的例子。

问题似乎可能出在线性层上,尽管我不知道为什么。

所以,我希望有人能解释一下错误出在哪里。将非常感谢任何可能的建议。


回答:

nn.moduleDict 是一个容器,它的 forward 函数没有定义。它应该用于存储子模块/网络。

你应该使用以有序字典 OrderedDict 初始化的 nn.Sequential

third_model = torch.nn.Sequential(    OrderedDict([       ('flatten', torch.nn.Flatten()),       ('fc1', torch.nn.Linear(32 * 32 * 3, 1024)),       ('relu', torch.nn.ReLU()),       ('fc2', torch.nn.Linear(1024, 240)),       ('relu', torch.nn.ReLU()),       ('fc3', torch.nn.Linear(240, 10))]))

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注