我正在研究这个脚本,其中有一个代码块考虑了两种选项,即DataParallel
和DistributedDataParallel
,如下所示:
if not args.distributed: if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): model.features = torch.nn.DataParallel(model.features) model.cuda() else: model = torch.nn.DataParallel(model).cuda()else: model.cuda() model = torch.nn.parallel.DistributedDataParallel(model)
如果我不想使用这两种选项,并且我想在没有DataParallel
的情况下运行它,我该怎么做呢?
我该如何定义我的模型,使其作为普通的nn
运行,而不进行任何并行化处理呢?
回答:
如果你不想使用这些功能,你可以简单地移除包装器,直接使用模型本身:
if not args.distributed: if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): model.features = model.features model.cuda() else: model = model.cuda()else: model.cuda() model = model
这样做可以将代码修改降到最低。当然,由于你对并行化不感兴趣,你可以将整个if
语句简化为类似于下面的形式:
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): model.features = model.featuresmodel = model.cuda()
请注意,这段代码假设你在GPU上运行。