I trained a Mask R-CNN network with PyTorch and am trying to use the resulting weights to predict the locations of apples in images.
I used the dataset from this paper, and this is the GitHub link to the code being used.
I simply followed the instructions provided in the ReadMe file.
This is the command I entered at the prompt (personal information removed):
    python predict_rcnn.py --data_path "my_directory\datasets\apples-minneapple\detection" --output_file "my_directory\samples\apples\predictions" --weight_file "my_directory\samples\apples\weights\model_19.pth" --mrcnn
model_19.pth is the file containing all the weights generated after the 19th epoch.
The error is as follows:
    Loading model
    Traceback (most recent call last):
      File "predict_rcnn.py", line 122, in <module>
        main(args)
      File "predict_rcnn.py", line 77, in main
        model.load_state_dict(checkpoint['model'], strict=False)
    KeyError: 'model'
For convenience, I'll paste predict_rcnn.py here:
    import os
    import torch
    import torch.utils.data
    import torchvision
    import numpy as np

    from data.apple_dataset import AppleDataset
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
    from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

    import utility.utils as utils
    import utility.transforms as T


    ######################################################
    # Predict with either a Faster-RCNN or Mask-RCNN predictor
    # using the MinneApple dataset
    ######################################################
    def get_transform(train):
        transforms = []
        transforms.append(T.ToTensor())
        if train:
            transforms.append(T.RandomHorizontalFlip(0.5))
        return T.Compose(transforms)


    def get_maskrcnn_model_instance(num_classes):
        # load an instance segmentation model pre-trained on COCO
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)

        # get number of input features for the classifier
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # replace the pre-trained head with a new one
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # now get the number of input features for the mask classifier
        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        # and replace the mask predictor with a new one
        model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
        return model


    def get_frcnn_model_instance(num_classes):
        # load an instance segmentation model pre-trained on COCO
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)

        # get number of input features for the classifier
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # replace the pre-trained head with a new one
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        return model


    def main(args):
        num_classes = 2
        device = args.device

        # Load the model from
        print("Loading model")
        # Create the correct model type
        if args.mrcnn:
            model = get_maskrcnn_model_instance(num_classes)
        else:
            model = get_frcnn_model_instance(num_classes)

        # Load model parameters and keep on CPU
        checkpoint = torch.load(args.weight_file, map_location=device)
        # checkpoint = torch.load(args.weight_file, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['model'], strict=False)
        model.eval()

        print("Creating data loaders")
        dataset_test = AppleDataset(args.data_path, get_transform(train=False))
        data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1,
                                                       shuffle=False, num_workers=1,
                                                       collate_fn=utils.collate_fn)

        # Create output directory
        base_path = os.path.dirname(args.output_file)
        if not os.path.exists(base_path):
            os.makedirs(base_path)

        # Predict on bboxes on each image
        f = open(args.output_file, 'a')
        for image, targets in data_loader_test:
            image = list(img.to(device) for img in image)
            outputs = model(image)

            for ii, output in enumerate(outputs):
                img_id = targets[ii]['image_id']
                img_name = data_loader_test.dataset.get_img_name(img_id)
                print("Predicting on image: {}".format(img_name))
                boxes = output['boxes'].detach().numpy()
                scores = output['scores'].detach().numpy()

                im_names = np.repeat(img_name, len(boxes), axis=0)
                stacked = np.hstack((im_names.reshape(len(scores), 1), boxes.astype(int), scores.reshape(len(scores), 1)))

                # File to write predictions to
                np.savetxt(f, stacked, fmt='%s', delimiter=',', newline='\n')


    if __name__ == "__main__":
        import argparse
        parser = argparse.ArgumentParser(description='PyTorch Detection')
        parser.add_argument('--data_path', required=True, help='path to the data to predict on')
        parser.add_argument('--output_file', required=True, help='path where to write the prediction outputs')
        parser.add_argument('--weight_file', required=True, help='path to the weight file')
        parser.add_argument('--device', default='cuda', help='device to use. Either cpu or cuda')

        model = parser.add_mutually_exclusive_group(required=True)
        model.add_argument('--frcnn', action='store_true', help='use a Faster-RCNN model')
        model.add_argument('--mrcnn', action='store_true', help='use a Mask-RCNN model')

        args = parser.parse_args()
        main(args)
Answer:
There is no 'model' key in the saved checkpoint. If you look at train_rcnn.py:106:
    torch.save(model.state_dict(), os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
you will see that they save only the model parameters. It should instead be something like this:
torch.save({ "model": model.state_dict(), "optimizer": optimizer.state_dict(), "lr_scheduler": lr_scheduler.state_dict()}, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
That way, after loading you get a dictionary containing 'model' along with the other parameters they apparently wanted to keep.
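If the checkpoint is saved in that dictionary form, the matching load is the usual PyTorch pattern. A minimal sketch of the loading side (checkpoint_path and device are placeholders here; model, optimizer and lr_scheduler are assumed to be the same objects the training script builds):

    # Sketch: loading a dictionary-style checkpoint, e.g. to resume training.
    # checkpoint_path is a placeholder for the .pth file saved above.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])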
Saving only the bare state_dict there seems to be a bug in their code.
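As a workaround for the weight file you already have (model_19.pth is a bare state_dict, not a dictionary with a 'model' key), you could also make the loading step in predict_rcnn.py tolerant of both formats instead of retraining. This is only a sketch, not the repository's code:

    # Sketch: accept both checkpoint formats in predict_rcnn.py.
    checkpoint = torch.load(args.weight_file, map_location=device)
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        # dictionary-style checkpoint: {"model": ..., "optimizer": ..., ...}
        model.load_state_dict(checkpoint['model'], strict=False)
    else:
        # bare state_dict saved via torch.save(model.state_dict(), ...)
        model.load_state_dict(checkpoint, strict=False)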