如何在训练模型时使用GPU?

我在Kaggle笔记本上运行代码来训练一个ResNet模型。我选择了GPU作为加速器,所以在这方面我没有犯错。我使用以下代码来训练模型:

model.cuda()for epoch in range(10):  model.train(True)  trainloss=0  for x,y in trainloader:    x,y=x.cuda(),y.cuda()    yhat=model(x)    optimizer.zero_grad()    loss=criterion(yhat,y)    loss.backward()    optimizer.step()    trainloss+=loss.item()  print('Epoch {}  Loss: {}'.format(epoch,(trainloss/len(trainloader.dataset))))  model.eval()  testcorrect=0  with torch.no_grad():    for test_x,test_y in testloader:      test_x,test_y=test_x.cuda(),test_y.cuda()      yhat=model(test_x)      _,z=yhat.max(1)      testcorrect+=(test_y==z).sum().item()print('Model Accuracy: ',(testcorrect/len(testloader.dataset)))

网络代码:

model=torchvision.models.resnet18(pretrained=True)num_ftrs=model.fc.in_featuresmodel.fc=nn.Sequential(nn.Linear(num_ftrs,1000),                        nn.ReLU(),                        nn.Linear(1000,2))

你可以看到,我在模型和张量上都使用了.cuda()函数(在训练部分和验证部分)。然而,Kaggle笔记本显示的GPU使用率为0%,而我的CPU使用率高达99%。我是否遗漏了使用GPU训练模型所需的任何代码?


回答:

可能是你的模型没有给GPU足够的工作量。尝试让你的网络更加依赖GPU,例如引入一些带有大量神经元的线性层等,以此来确认在这种情况下你能看到GPU使用率的增加。此外,我注意到测量结果有一点延迟,所以可能你给GPU的工作量是它可以在几分之一秒内完成的,GPU使用率条因此没有机会从0%上升。

你可以分享你实际使用的网络吗?

我可以看到在Kaggle笔记本上使用一个这样的玩具示例时,GPU使用率达到100%(注意这里有2500 x 2500的线性层):

import torchimport torch.nn as nnimport torch.optim as optimimport numpy as nptrainloader = [(torch.Tensor(np.random.randn(1000, 5)), torch.Tensor([1.0] * 1000))] * 1000model = nn.Sequential(nn.Linear(5, 2500), nn.Linear(2500, 1500), nn.Linear(1500, 1))model.cuda()optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.)criterion = lambda x,y : ((x-y)**2).mean()for epoch in range(10):  for x,y in trainloader:    x,y=x.cuda(),y.cuda()    yhat=model(x)    optimizer.zero_grad()    loss=criterion(yhat,y)    loss.backward()    optimizer.step()  print(epoch)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注