使用Python替换矩阵中的特定值

我有一个m x n的矩阵,每一行是一个样本,每一列是一个类别。每行包含每个类别的软最大概率。我想将每行的最大值替换为1,其余值替换为0。如何在Python中高效地完成这个操作?


回答:

我认为对你这个问题的最佳回答是使用矩阵类型对象。

考虑到矩阵的大部分元素都是零,使用稀疏矩阵在存储大量大尺寸矩阵时应该是最节省内存的。这在内存使用方面应该优于直接使用numpy数组,特别是对于两个维度都非常大的矩阵,即使计算速度可能没有那么快,但在内存使用上更有优势。

import numpy as npimport scipy       #旧版本可能需要`import scipy.sparse`matrix = np.matrix(np.random.randn(10, 5))maxes = matrix.argmax(axis=1).A1                                 # 之前是.A[:,0],稍微快一些,但.A1看起来更易读n_rows = len(matrix)  # 可以使用matrix.shape[0],但那会更慢data = np.ones(n_rows)row = np.arange(n_rows)sparse_matrix = scipy.sparse.coo_matrix((data, (row, maxes)),                                         shape=matrix.shape,                                         dtype=np.int8)

这个稀疏矩阵对象相对于常规矩阵对象应该非常轻量,因为它不会无谓地跟踪其中的每个零。要将其转换为普通矩阵:

sparse_matrix.todense()

返回:

matrix([[0, 0, 0, 0, 1],        [0, 0, 1, 0, 0],        [0, 0, 1, 0, 0],        [0, 0, 0, 0, 1],        [1, 0, 0, 0, 0],        [0, 0, 1, 0, 0],        [0, 0, 0, 1, 0],        [0, 1, 0, 0, 0],        [1, 0, 0, 0, 0],        [0, 0, 0, 1, 0]], dtype=int8)

我们可以将其与matrix进行比较:

matrix([[ 1.41049496,  0.24737968, -0.70849012,  0.24794031,  1.9231408 ],        [-0.08323096, -0.32134873,  2.14154425, -1.30430663,  0.64934781],        [ 0.56249379,  0.07851507,  0.63024234, -0.38683508, -1.75887624],        [-0.41063182,  0.15657594,  0.11175805,  0.37646245,  1.58261556],        [ 1.10421356, -0.26151637,  0.64442885, -1.23544526, -0.91119517],        [ 0.51384883,  1.5901419 ,  1.92496778, -1.23541699,  1.00231508],        [-2.42759787, -0.23592018, -0.33534536,  0.17577329, -1.14793293],        [-0.06051458,  1.24004714,  1.23588228, -0.11727146, -0.02627196],        [ 1.66071534, -0.07734444,  1.40305686, -1.02098911, -1.10752638],        [ 0.12466003, -1.60874191,  1.81127175,  2.26257234, -1.26008476]])

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注