如何在不使用循环的情况下获取NumPy数组中每个元素大于或小于的元素索引?

我正在从头开始编写决策树算法,目前我试图将数据分成几组,每组包含大于或等于或小于NumPy数组中某个连续DataFrame列的值的元素,并获取这些分组的目标值的平均值。我目前的代码如下:

for i in range(len(columns)):    col = columns[i]    # cont - 我的DataFrame中连续列的列表    if col in cont:        values  = xs[col].values        targets = y.values        for j in range(len(values)):            value = values[j]            greater_idx = np.where(values >= value)[0]            less_idx    = np.where(values <  value)[0]            targets_greater = targets[greater_idx].sum()            targets_less    = targets[less_idx]   .sum()        print(targets_greater/(j+1))        print(targets_less   /(j+1))

xs DataFrame的长度接近40万,因此循环非常慢,每次都会导致我的Jupyter Notebook内核崩溃。我知道应该有办法完全去掉这个循环,但我不知道该怎么做。


回答:

与其使用向量化方式进行比较,还有很多算法改进的空间:

  1. 使用np.argsort获取xs[col].values的排序索引(sorted_idxs)。
  2. 使用np.insert(np.cumsum(targets[sorted_idxs]), 0, 0)[:-1]可以为xs[col].values中的每个值得到一个target_less向量。
  3. target_less[0](0)是xs[col].values中最低元素的target_less值 – 要“取消排序”target_less,可以使用unsort_idx = np.argsort(sorted_idxs)target_less[unsort_idx]

现在您已经为数组中的所有值获取了所有target_less值(当然,target_greater可以通过targets.sum() - target_less轻松获得)。

编辑:

以下是与建议配套的代码:

import numpy as npimport pandas as pdxs = pd.DataFrame(np.random.random(10000))y = pd.Series(np.random.randint(0, 2, size=10000))sorted_idxs = np.argsort(xs[0].values)sorted_values = xs[0].values[sorted_idxs]sorted_targets = y.values[sorted_idxs]sorted_targets_less = np.insert(np.cumsum(sorted_targets), 0, 0)[:-1]unsorted_idxs = np.argsort(sorted_idxs)targets_less = sorted_targets_less[unsorted_idxs]for i, target_less_value in enumerate(targets_less):    assert target_less_value == y.values[np.where(xs.values < xs.values[i])[0]].sum()

一个警告:上述假设xs.values中的值是严格不同的。如果您有重复的值,那么您需要调整执行累积和的部分。

Related Posts

如何对SVC进行超参数调优?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

如何在初始训练后向模型添加训练数据?

我想在我的scikit-learn模型已经训练完成后再…

使用Google Cloud Function并行运行带有不同用户参数的相同训练作业

我正在寻找一种方法来并行运行带有不同用户参数的相同训练…

加载Keras模型,TypeError: ‘module’ object is not callable

我已经在StackOverflow上搜索并阅读了文档,…

在计算KNN填补方法中特定列中NaN值的”距离平均值”时

当我从头开始实现KNN填补方法来处理缺失数据时,我遇到…

使用巨大的S3 CSV文件或直接从预处理的关系型或NoSQL数据库获取数据的机器学习训练/测试工作

已关闭。此问题需要更多细节或更清晰的说明。目前不接受回…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注