我编写了一个脚本,使用我在MxNet上训练的模型对单个输入图像进行分类。为了对传入的图像进行分类,我将它们前馈通过网络。
简而言之,我正在做以下工作:
symbol, arg_params, aux_params = mx.model.load_checkpoint('model-prefix', 42)model = mx.mod.Module(symbol=symbol, context=mx.cpu())model.bind(data_shapes=[('data', (1, 3, 224, 244))], for_training=False)model.set_params(arg_params, aux_params)# ... 加载图像并调整大小 ...# img 是要分类的图像,作为形状为 (3, 244, 244) 的 numpy 数组Batch = namedtuple('Batch', ['data'])self._model.forward(Batch(data=[mx.nd.array(img)]))probabilities = self._model.get_outputs()[0].asnumpy()print(str(probabilities))
这运行得很好,除了我收到了以下警告:
UserWarning: Data provided by label_shapes don't match names specified by label_names ([] vs. ['softmax_label'])
我应该做哪些更改来避免这个警告? 我不清楚 label_shapes 和 label_names 参数的用途,以及我应该用什么来填充它们。
注意:我找到了一些关于它们的讨论,但没有一个能帮助我解决问题。同样,MxNet的文档也没有提供关于这些参数的详细信息,以及它们应该如何填充的说明。
回答:
设置 label_names=None
和 allow_missing=True
。这应该可以消除警告。
model = mx.mod.Module(symbol=symbol, context=mx.cpu(), label_names=None)...model.set_params(arg_params, aux_params, allow_missing=True)
如果你好奇为什么最初会打印这个警告,
每个模块都有一个关联的标签。当这个模型被训练时,softmax_label
被用作标签(很可能是由于输出层是一个名为 ‘softmax’ 的 softmax 层)。当模型从文件中加载时,创建的模块使用 softmax_label
作为模块的标签。
>>>print(model.label_names)['softmax_label']
然后调用 model.bind
时没有提供 label_shapes。
model.bind(data_shapes=[('data', (1, 3, 224, 244))], for_training=False)
MXNet 发现模块中有一个标签,但在绑定时没有提供这个标签,因此会抱怨 – 这就是你看到的警告信息。
我认为如果 bind
调用时设置了 for_training=False
,MXNet 不应该抱怨缺少标签。我已经创建了这个问题:https://github.com/dmlc/mxnet/issues/6958
然而,对于这种从磁盘加载模型的特定情况,我们可以将标签设置为 None
,这样当 bind
没有提供标签时,MXNet 就不会抱怨 – 这正是建议的修复方法所做的。