如何在PyTorch中测试单张图像

我在PyTorch中创建了一个模型,并且运行得很好,但是当我想要测试单张图像时,设置batch_size=1总是返回第二个类(在这种情况下是狗)。

我尝试了batch大于1的情况,在所有情况下都能正常工作!

模型架构如下:

model = models.densenet121(pretrained=True)for param in model.parameters():param.requires_grad = Falsefrom collections import OrderedDictclassifier = nn.Sequential(OrderedDict([                          ('fc1', nn.Linear(1024, 500)),                          ('relu', nn.ReLU()),                          ('fc2', nn.Linear(500, 2)),                          ('output', nn.LogSoftmax(dim=1))                          ]))model.classifier = classifier

所以我的张量形状是 [batch, 3, 224, 224]

我尝试了以下方法:

resizereshapeunsqueeze(0)

当只有一张图像时,响应总是 [[0.4741, 0.5259]]

我的测试代码

from PIL import *msize = 256loader = transforms.Compose([transforms.Scale(imsize), transforms.ToTensor()])def image_loader(image_name):    """load image, returns cuda tensor"""    image = Image.open(image_name)    image = loader(image).float()    image = image.unsqueeze(0)     return image.cuda()image = image_loader('Cat_Dog_data/test/cat/cat.16.jpg') with torch.no_grad():    logits = model.forward(image)ps = torch.exp(logits)_, predTest = torch.max(ps,1)print(ps) ## same value in all casesimagen_mostrar = images[ii].to('cpu') helper.imshow(imagen_mostrar,title=clas_perro_gato(predTest), normalize=True)

第二个测试代码

andrea_data = datasets.ImageFolder(data_dir + '/andrea', transform=test_transforms)andrealoader = torch.utils.data.DataLoader(andrea_data, batch_size=1, shuffle=True)dataiter = iter(andrealoader)images, labels = dataiter.next()images, labels = images.to(device), labels.to(device)ps = torch.exp(model.forward(images))_, predTest = torch.max(ps,1) print(ps.float())

如果我将batch_size改为1,总是返回一个张量,说明这是一只狗,例如[0.43,0.57]。

谢谢!


回答:

我意识到我的模型没有处于eval模式。

所以我只是添加了model.eval(),现在无论batch大小是多少都能正常工作了

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注