我需要一些帮助来优化一段Python代码

我正在使用Python开发一个KNN分类器,但遇到了一些问题。以下代码段的执行时间在7.5秒到9.0秒之间,而我需要运行60,000次。

        for fold in folds:              for dot2 in fold:                """                distances[x][0] = Class of the dot2                distances[x][1] = distance between dot1 and dot2                """                distances.append([dot2[0], calc_distance(dot1[1:], dot2[1:], method)])

变量”folds”是一个包含10个折叠的列表,总共包含60,000个以.csv格式存储的图像输入。每个点的第一个值是它所属的类别。所有值都是整数。有没有办法让这行代码运行得更快?

这是calc_distance函数

def calc_distancia(dot1, dot2, distance):if distance == "manhanttan":    total = 0    #for each coord, take the absolute difference    for x in range(0, len(dot1)):        total = total + abs(dot1[x] - dot2[x])    return totalelif distance == "euclidiana":    total = 0    for x in range(0, len(dot1)):        total = total + (dot1[x] - dot2[x])**2    return math.sqrt(total)elif distance == "supremum":    total = 0    for x in range(0, len(dot1)):        if abs(dot1[x] - dot2[x]) > total:            total = abs(dot1[x] - dot2[x])    return totalelif distance == "cosseno":    dist = 0    p1_p2_mul = 0    p1_sum = 0    p2_sum = 0    for x in range(0, len(dot1)):        p1_p2_mul = p1_p2_mul + dot1[x]*dot2[x]        p1_sum = p1_sum + dot1[x]**2        p2_sum = p2_sum + dot2[x]**2    p1_sum = math.sqrt(p1_sum)    p2_sum = math.sqrt(p2_sum)    quociente = p1_sum*p2_sum    dist = p1_p2_mul/quociente    return dist

编辑:找到了至少对于”manhanttan”方法加速的方法。原来的代码是:

    if distance == "manhanttan":    total = 0    #for each coord, take the absolute difference    for x in range(0, len(dot1)):        total = total + abs(dot1[x] - dot2[x])    return total

我改成了

    if distance == "manhanttan":    totalp1 = 0    totalp2 = 0    #for each coord, take the absolute difference    for x in range(0, len(dot1)):        totalp1 += dot1[x]        totalp2 += dot2[x]    return abs(totalp1-totalp2)

abs()调用非常耗时


回答:

有许多关于“Python性能分析”的指南;你应该搜索一些,阅读它们,并逐步进行性能分析过程,以确保你知道工作中哪些部分耗时最多。

但如果这是你工作的核心,那么可以合理地推测calc_distance是消耗大部分运行时间的地方。

要深入优化它,可能需要使用NumPy加速的数学运算或类似的、更低级的方法。

作为一种快速且不那么侵入性的性能分析和重写方法,可以尝试安装PyPy版本的Python,并在其上运行。我看到过与标准(CPython)实现相比,容易获得2倍或更多的加速效果。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注