我有一个形状为(192, 224, 192, 1)
的numpy数组。最后一个维度是整数类别,我想对其进行one-hot编码。例如,如果我有12个类别,我希望结果数组的形状为(192, 224, 192, 12)
,最后一个维度除了对应原始值的索引位置为1外,其余都为0。
我可以用多个for
循环来实现,但我想知道是否有更好的方法来做这件事。
回答:
如果你知道最大值,可以通过一次索引操作来完成。假设有一个数组a
和m = a.max() + 1
:
out = np.zeros(a.shape[:-1] + (m,), dtype=bool)out[(*np.indices(a.shape[:-1], sparse=True), a[..., 0])] = True
如果你去掉不必要的尾随维度,操作会更简单:
a = np.squeeze(a)out = np.zeros(a.shape + (m,), bool)out[(*np.indices(a.shape, sparse=True), a)] = True
索引中的显式元组是进行星号展开所必需的。
如果你想将这个方法扩展到任意维度,你也可以做到。以下代码将在axis
处插入一个新维度到压缩后的数组中。这里的axis
是新轴在最终数组中的位置,这与np.stack
一致,但与list.insert
不一致:
def onehot(a, axis=-1, dtype=bool): pos = axis if axis >= 0 else a.ndim + axis + 1 shape = list(a.shape) shape.insert(pos, a.max() + 1) out = np.zeros(shape, dtype) ind = list(np.indices(a.shape, sparse=True)) ind.insert(pos, a) out[tuple(ind)] = True return out
如果你有一个单一维度需要扩展,通用解决方案可以找到第一个可用的单一维度:
def onehot2(a, axis=None, dtype=bool): shape = np.array(a.shape) if axis is None: axis = (shape == 1).argmax() if shape[axis] != 1: raise ValueError(f'Dimension at {axis} is non-singleton') shape[axis] = a.max() + 1 out = np.zeros(shape, dtype) ind = list(np.indices(a.shape, sparse=True)) ind[axis] = a out[tuple(ind)] = True return out
要使用最后一个可用的单一维度,请将axis = (shape == 1).argmax()
替换为
axis = a.ndim - 1 - (shape[::-1] == 1).argmax()
以下是一些示例用法:
>>> np.random.seed(0x111)>>> x = np.random.randint(5, size=(3, 2))>>> xarray([[2, 3], [3, 1], [4, 0]])>>> a = onehot(x, axis=-1, dtype=int)>>> a.shape(3, 2, 5)>>> aarray([[[0, 0, 1, 0, 0], # 2 [0, 0, 0, 1, 0]], # 3 [[0, 0, 0, 1, 0], # 3 [0, 1, 0, 0, 0]], # 1 [[0, 0, 0, 0, 1], # 4 [1, 0, 0, 0, 0]]] # 0>>> b = onehot(x, axis=-2, dtype=int)>>> b.shape(3, 5, 2)>>> barray([[[0, 0], [0, 0], [1, 0], [0, 1], [0, 0]], [[0, 0], [0, 1], [0, 0], [1, 0], [0, 0]], [[0, 1], [0, 0], [0, 0], [0, 0], [1, 0]]])
onehot2
要求你标记你想添加的维度为单一维度:
>>> np.random.seed(0x111)>>> y = np.random.randint(5, size=(3, 1, 2, 1))>>> yarray([[[[2], [3]]], [[[3], [1]]], [[[4], [0]]]])>>> c = onehot2(y, axis=-1, dtype=int)>>> c.shape(3, 1, 2, 5)>>> carray([[[[0, 0, 1, 0, 0], [0, 0, 0, 1, 0]]], [[[0, 0, 0, 1, 0], [0, 1, 0, 0, 0]]], [[[0, 0, 0, 0, 1], [1, 0, 0, 0, 0]]]])>>> d = onehot2(y, axis=-2, dtype=int)ValueError: Dimension at -2 is non-singleton>>> e = onehot2(y, dtype=int)>>> e.shape(3, 5, 2, 1)>>> e.squeeze()array([[[0, 0], [0, 0], [1, 0], [0, 1], [0, 0]], [[0, 0], [0, 1], [0, 0], [1, 0], [0, 0]], [[0, 1], [0, 0], [0, 0], [0, 0], [1, 0]]])