使用One-Hot编码处理维度大于2的numpy数组

我有一个形状为(192, 224, 192, 1)的numpy数组。最后一个维度是整数类别,我想对其进行one-hot编码。例如,如果我有12个类别,我希望结果数组的形状为(192, 224, 192, 12),最后一个维度除了对应原始值的索引位置为1外,其余都为0。

我可以用多个for循环来实现,但我想知道是否有更好的方法来做这件事。


回答:

如果你知道最大值,可以通过一次索引操作来完成。假设有一个数组am = a.max() + 1

out = np.zeros(a.shape[:-1] + (m,), dtype=bool)out[(*np.indices(a.shape[:-1], sparse=True), a[..., 0])] = True

如果你去掉不必要的尾随维度,操作会更简单:

a = np.squeeze(a)out = np.zeros(a.shape + (m,), bool)out[(*np.indices(a.shape, sparse=True), a)] = True

索引中的显式元组是进行星号展开所必需的。

如果你想将这个方法扩展到任意维度,你也可以做到。以下代码将在axis处插入一个新维度到压缩后的数组中。这里的axis是新轴在最终数组中的位置,这与np.stack一致,但与list.insert不一致:

def onehot(a, axis=-1, dtype=bool):    pos = axis if axis >= 0 else a.ndim + axis + 1    shape = list(a.shape)    shape.insert(pos, a.max() + 1)    out = np.zeros(shape, dtype)    ind = list(np.indices(a.shape, sparse=True))    ind.insert(pos, a)    out[tuple(ind)] = True    return out

如果你有一个单一维度需要扩展,通用解决方案可以找到第一个可用的单一维度:

def onehot2(a, axis=None, dtype=bool):    shape = np.array(a.shape)    if axis is None:        axis = (shape == 1).argmax()    if shape[axis] != 1:        raise ValueError(f'Dimension at {axis} is non-singleton')    shape[axis] = a.max() + 1    out = np.zeros(shape, dtype)    ind = list(np.indices(a.shape, sparse=True))    ind[axis] = a    out[tuple(ind)] = True    return out

要使用最后一个可用的单一维度,请将axis = (shape == 1).argmax()替换为

axis = a.ndim - 1 - (shape[::-1] == 1).argmax()

以下是一些示例用法:

>>> np.random.seed(0x111)>>> x = np.random.randint(5, size=(3, 2))>>> xarray([[2, 3],       [3, 1],       [4, 0]])>>> a = onehot(x, axis=-1, dtype=int)>>> a.shape(3, 2, 5)>>> aarray([[[0, 0, 1, 0, 0],    # 2        [0, 0, 0, 1, 0]],   # 3       [[0, 0, 0, 1, 0],    # 3        [0, 1, 0, 0, 0]],   # 1       [[0, 0, 0, 0, 1],    # 4        [1, 0, 0, 0, 0]]]   # 0>>> b = onehot(x, axis=-2, dtype=int)>>> b.shape(3, 5, 2)>>> barray([[[0, 0],        [0, 0],        [1, 0],        [0, 1],        [0, 0]],       [[0, 0],        [0, 1],        [0, 0],        [1, 0],        [0, 0]],       [[0, 1],        [0, 0],        [0, 0],        [0, 0],        [1, 0]]])

onehot2要求你标记你想添加的维度为单一维度:

>>> np.random.seed(0x111)>>> y = np.random.randint(5, size=(3, 1, 2, 1))>>> yarray([[[[2],         [3]]],       [[[3],         [1]]],       [[[4],         [0]]]])>>> c = onehot2(y, axis=-1, dtype=int)>>> c.shape(3, 1, 2, 5)>>> carray([[[[0, 0, 1, 0, 0],         [0, 0, 0, 1, 0]]],       [[[0, 0, 0, 1, 0],         [0, 1, 0, 0, 0]]],       [[[0, 0, 0, 0, 1],         [1, 0, 0, 0, 0]]]])>>> d = onehot2(y, axis=-2, dtype=int)ValueError: Dimension at -2 is non-singleton>>> e = onehot2(y, dtype=int)>>> e.shape(3, 5, 2, 1)>>> e.squeeze()array([[[0, 0],        [0, 0],        [1, 0],        [0, 1],        [0, 0]],       [[0, 0],        [0, 1],        [0, 0],        [1, 0],        [0, 0]],       [[0, 1],        [0, 0],        [0, 0],        [0, 0],        [1, 0]]])

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注