如何在不使用循环的情况下获取NumPy数组中每个元素大于或小于的元素索引?

我正在从头开始编写决策树算法,目前我试图将数据分成几组,每组包含大于或等于或小于NumPy数组中某个连续DataFrame列的值的元素,并获取这些分组的目标值的平均值。我目前的代码如下:

for i in range(len(columns)):    col = columns[i]    # cont - 我的DataFrame中连续列的列表    if col in cont:        values  = xs[col].values        targets = y.values        for j in range(len(values)):            value = values[j]            greater_idx = np.where(values >= value)[0]            less_idx    = np.where(values <  value)[0]            targets_greater = targets[greater_idx].sum()            targets_less    = targets[less_idx]   .sum()        print(targets_greater/(j+1))        print(targets_less   /(j+1))

xs DataFrame的长度接近40万,因此循环非常慢,每次都会导致我的Jupyter Notebook内核崩溃。我知道应该有办法完全去掉这个循环,但我不知道该怎么做。


回答:

与其使用向量化方式进行比较,还有很多算法改进的空间:

  1. 使用np.argsort获取xs[col].values的排序索引(sorted_idxs)。
  2. 使用np.insert(np.cumsum(targets[sorted_idxs]), 0, 0)[:-1]可以为xs[col].values中的每个值得到一个target_less向量。
  3. target_less[0](0)是xs[col].values中最低元素的target_less值 – 要“取消排序”target_less,可以使用unsort_idx = np.argsort(sorted_idxs)target_less[unsort_idx]

现在您已经为数组中的所有值获取了所有target_less值(当然,target_greater可以通过targets.sum() - target_less轻松获得)。

编辑:

以下是与建议配套的代码:

import numpy as npimport pandas as pdxs = pd.DataFrame(np.random.random(10000))y = pd.Series(np.random.randint(0, 2, size=10000))sorted_idxs = np.argsort(xs[0].values)sorted_values = xs[0].values[sorted_idxs]sorted_targets = y.values[sorted_idxs]sorted_targets_less = np.insert(np.cumsum(sorted_targets), 0, 0)[:-1]unsorted_idxs = np.argsort(sorted_idxs)targets_less = sorted_targets_less[unsorted_idxs]for i, target_less_value in enumerate(targets_less):    assert target_less_value == y.values[np.where(xs.values < xs.values[i])[0]].sum()

一个警告:上述假设xs.values中的值是严格不同的。如果您有重复的值,那么您需要调整执行累积和的部分。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注