After training the network, I noticed that the accuracy goes up and down. At first I thought this was caused by the learning rate, but the learning rate is set quite small. Please see the attached screenshot. [Screenshot: accuracy plot]
My network (in PyTorch) looks like this:
import torch
import torch.nn as nn

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc1 = nn.Linear(17*17*64, 512)
        self.fc2 = nn.Linear(512, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        out = torch.sigmoid(out)
        return out
I am using RMSprop as the optimizer and BCELoss as the loss function. The learning rate is set to 0.001.
Here is the training loop:
epochs = 15
itr = 1
p_itr = 100
model.train()
total_loss = 0
loss_list = []
acc_list = []

for epoch in range(epochs):
    for samples, labels in train_loader:
        samples, labels = samples.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(samples)
        labels = labels.unsqueeze(-1)
        labels = labels.float()
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        scheduler.step()
        if itr % p_itr == 0:
            pred = torch.argmax(output, dim=1)
            correct = pred.eq(labels)
            acc = torch.mean(correct.float())
            print('[Epoch {}/{}] Iteration {} -> Train Loss: {:.4f}, Accuracy: {:.3f}'.format(epoch+1, epochs, itr, total_loss/p_itr, acc))
            loss_list.append(total_loss/p_itr)
            acc_list.append(acc)
            total_loss = 0
        itr += 1
My dataset is quite small – 2000 training samples and 1000 validation samples (binary classification, 0/1). I originally wanted to do an 80/20 split, but I was asked to keep it as is. I am wondering whether the architecture might be too complex for such a small dataset.
Any suggestions as to what might explain these fluctuations during training?
Answer:
Your code has a bug here:

pred = torch.argmax(output, dim=1)

This line is meant for multi-class classification with cross-entropy loss. Your task is binary classification, so the values of pred are wrong. Change it to:
if itr % p_itr == 0:
    pred = torch.round(output)
    ....
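For reference, here is a minimal sketch of the corrected accuracy computation, filling in the rest from the loop you posted (all variable names are taken from your code):

if itr % p_itr == 0:
    # Sigmoid output is a probability in [0, 1]; round thresholds it at 0.5
    pred = torch.round(output)         # instead of torch.argmax(output, dim=1)
    correct = pred.eq(labels)          # labels has shape [batch, 1] after unsqueeze
    acc = torch.mean(correct.float())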
You can change the optimizer to Adam, SGD, or RMSprop to find the one that helps your model converge faster; see the sketch below.
Also change the forward() function:
def forward(self, x):
    out = self.layer1(x)
    out = self.layer2(out)
    out = self.layer3(out)
    out = out.view(out.size(0), -1)
    out = self.relu(self.fc1(out))
    out = self.fc2(out)
    return self.sigmoid(out)  # using your forward() also works, but this is more concise
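For completeness, a minimal sketch of how the pieces described in your post fit together; the exact setup is an assumption on my part, with variable names mirroring your loop:

model = Network().to(device)                              # `device` as in your training loop
criterion = nn.BCELoss()                                  # expects sigmoid probabilities and float targets
optimizer = optim.RMSprop(model.parameters(), lr=0.001)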