评估每轮训练后神经网络的准确性

from dataset import get_strange_symbol_loader, get_strange_symbols_test_dataimport torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimclass Net(nn.Module):   def __init__(self):    super().__init__()    self.fc1 = nn.Linear(28*28, 512)    self.fc2 = nn.Linear(512, 256)    self.fc3 = nn.Linear(256, 15)def forward(self,x):    x = F.relu(self.fc1(x))    x = F.relu(self.fc2(x))    x = self.fc3(x)    return F.softmax(x, dim=1)if __name__ == '__main__':   net = Net()   train, test = get_strange_symbol_loader(batch_size=128)   loss_function = nn.CrossEntropyLoss()   optimizer = optim.Adam(net.parameters(), lr=1e-3)   Accuracy = []   for epoch in range(30):          print("epoch",epoch)       #Train       for data in train:           img, label = data             net.zero_grad()           output = net(img.view(-1,28*28))           loss = F.nll_loss(output, label)           loss.backward()           optimizer.step()       #Test           correct, total = 0, 0       with torch.no_grad():          for data in test:               img, label = data               output = net(img.view(-1,784))               for idx, i in enumerate(output):                   if torch.argmax(i) == label[idx]:                       correct += 1                       total += 1       Accuracy.append(round(correct/total, 3))       print("Accuracy: ",Accuracy)

这是我基于Sentdex的PyTorch神经网络,通过大学课程管理员提供的数据集进行训练,该数据集通过函数get_strange_symbol_loader(batch_size=128)导入。

当我运行这段代码时,它告诉我每轮训练的准确性应该是1.0。然而,在包含epoch的for循环迭代后运行#Test块会得到更现实的结果。这是为什么呢?

我的目标是绘制测试准确性与epoch数量的关系图,以找到模型在开始过拟合之前的最佳epoch数量。


回答:

您在以下代码块中同时增加了correcttotal的值

if torch.argmax(i) == label[idx]:    correct += 1    total += 1

因此两者总是相同的值,一个除以另一个得到1.0

检查您的缩进,我认为从total +=1移除一个制表符应该可以解决问题。

编辑:我假设您说的“在运行#test块之后…”是指您运行了另一个可能不同的代码片段(可能是正确缩进的)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注