在2D列表中获取正确的值

我有一个多项式回归图表,并且试图使用下一个X值来查找预测的y值。

"""Fit a degree-4 polynomial to each player's scrambling history and plot it.

Reads ``Combined_Player_Stats.json`` (a list of player objects; this script
uses the ``Name`` and ``Scrambling_List`` keys). For every player it fits
``scrambling = f(tournament_index)`` with ``PolynomialFeatures`` +
``LinearRegression``, plots the observations, the fitted curve and the
forecast for the next tournament, then prints the name and the forecast.
"""
import json

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

STATS_PATH = '/Users/aus10/Desktop/PGA/Data_Cleanup/Combined_Player_Stats.json'
POLY_DEGREE = 4


def create_2d_lst(lst):
    """Return ``[[index, value], ...]`` pairs for *lst* (0-based indices).

    An empty or ``None`` list yields the single placeholder point
    ``[[0, 0]]`` so the result can always be sliced as a 2-D array.
    """
    if not lst:
        # Must stay 2-D: np.array([0, 0])[:, 0] raises IndexError.
        return [[0, 0]]
    return [[i, value] for i, value in enumerate(lst)]


def viz_polymonial(X, y, poly_reg, pol_reg, next_x, prediction):
    """Plot the observations, the fitted curve and the forecast point."""
    plt.scatter(X, y, color='red')
    plt.plot(X, pol_reg.predict(poly_reg.transform(X)), color='blue')
    # The forecast needs an explicit x coordinate. Calling
    # plt.plot(prediction, ...) with only the y value (the bug described in
    # this question) makes matplotlib use the index, i.e. x = 0.
    plt.plot(next_x, prediction[0, 0], marker='x', color='green')
    plt.title('Projected Scrambling Percentage')
    plt.xlabel('Tournaments')
    plt.ylabel('Scrambling Percentage')
    plt.show()


if __name__ == '__main__':
    with open(STATS_PATH, encoding='utf-8') as json_file:
        players_data = json.load(json_file)

    for obj in players_data:
        # Drop falsy entries (None, '' -- NOTE: this also drops a genuine 0).
        scores = [i for i in (obj.get('Scrambling_List') or []) if i]
        obj['Scrambling_List'] = scores

        data = np.array(create_2d_lst(scores), dtype=float)
        X = data[:, 0].reshape(-1, 1)
        y = data[:, 1].reshape(-1, 1)

        poly_reg = PolynomialFeatures(degree=POLY_DEGREE)
        pol_reg = LinearRegression()
        pol_reg.fit(poly_reg.fit_transform(X), y)

        # X holds indices 0 .. len(X) - 1, so the next tournament is at
        # len(X); len(X) + 1 would skip a tournament.
        next_x = len(X)
        # transform (not fit_transform): reuse the already-fitted expansion.
        prediction = pol_reg.predict(poly_reg.transform([[next_x]]))

        viz_polymonial(X, y, poly_reg, pol_reg, next_x, prediction)
        print(obj['Name'], prediction)

当我使用 `prediction_value = (len(X) + 1)` 和 `prediction = pol_reg.predict(poly_reg.fit_transform([[prediction_value]]))` 时,预测点应该位于下一个X值,也就是 len(X) + 1,但它在图表上的X值却是0。我需要把预测点放在正确的X值上。我不确定为什么是0,因为当我打印预测值时,得到的是正确的值。

这是json文件的副本

[  {    "Name": "Aaron Baddeley",    "Tournaments": [      {        "Scrambling": 71.43,        "Total_Putts_GIR": 75,        "SG_Putting": 0.31,        "Tournament": "Safeway_Open",        "Date": "08-26-2019"      },      {        "Scrambling": 55.56,        "Total_Putts_GIR": 92,        "SG_Putting": 0.03,        "Tournament": "Shriners_Hospital_for_Children_Open",        "Date": "10-08-2019"      },      {        "Scrambling": 40,        "Total_Putts_GIR": 47,        "SG_Putting": -0.14,        "Tournament": "Houston",        "Date": "10-10-2019"      },      {        "Scrambling": 71.43,        "Total_Putts_GIR": 93,        "SG_Putting": -0.37,        "Tournament": "Waste_Management",        "Date": "01-30-2020"      },      {        "Scrambling": 75,        "Total_Putts_GIR": 29,        "SG_Putting": 0.69,        "Tournament": "The_Genesis",        "Date": "02-13-2020"      },      {        "Scrambling": 71.43,        "Total_Putts_GIR": 38,        "SG_Putting": -0.82,        "Tournament": "RBC_Heritage",        "Date": "06-18-2020"      },      {        "Scrambling": 50,        "Total_Putts_GIR": 30,        "SG_Putting": 0.88,        "Tournament": "Travelers",        "Date": "06-25-2020"      },      {        "Scrambling": 42.86,        "Total_Putts_GIR": 53,        "SG_Putting": -1.18,        "Tournament": "Rocket_Mortgage",        "Date": "07-02-2020"      },      {        "Scrambling": 43.75,        "Total_Putts_GIR": 33,        "SG_Putting": 1.41,        "Tournament": "Workday",        "Date": "07-09-2020"      }    ],    "Scrambling_List": [      71.43,      55.56,      40,      71.43,      75,      71.43,      50,      42.86,      43.75    ],    "Total_Putts_GIR_List": [      75,      92,      47,      93,      29,      38,      30,      53,      33    ],    "SG_Putting_List": [      0.31,      0.03,      -0.14,      -0.37,      0.69,      -0.82,      0.88,      -1.18,      1.41    ]  }]

回答:

已经解决了。原来的 `plt.plot(prediction, ...)` 只传入了y值而没有传入X值,这时 matplotlib 会用数据的索引作为X坐标,所以预测点被画在了X=0的位置。解决方法是在绘制预测点时显式传入X坐标:

"""Fit a degree-4 polynomial to each player's scrambling history and plot it.

Reads ``Combined_Player_Stats.json`` (a list of player objects; this script
uses the ``Name`` and ``Scrambling_List`` keys). For every player it fits
``scrambling = f(tournament_index)`` with ``PolynomialFeatures`` +
``LinearRegression``, plots the observations, the fitted curve and the
forecast for the next tournament, then prints the name and the forecast.
"""
import json

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

STATS_PATH = '/Users/aus10/Desktop/PGA/Data_Cleanup/Combined_Player_Stats.json'
POLY_DEGREE = 4


def create_2d_lst(lst):
    """Return ``[[index, value], ...]`` pairs for *lst* (0-based indices).

    An empty or ``None`` list yields the single placeholder point
    ``[[0, 0]]`` so the result can always be sliced as a 2-D array.
    """
    if not lst:
        # Must stay 2-D: np.array([0, 0])[:, 0] raises IndexError.
        return [[0, 0]]
    return [[i, value] for i, value in enumerate(lst)]


def viz_polymonial(X, y, poly_reg, pol_reg, next_x, prediction):
    """Plot the observations, the fitted curve and the forecast point."""
    plt.scatter(X, y, color='red')
    plt.plot(X, pol_reg.predict(poly_reg.transform(X)), color='blue')
    # Pass the x coordinate explicitly; plt.plot(prediction, ...) alone
    # would place the point at x = 0.
    plt.plot(next_x, prediction[0, 0], marker='x', color='green')
    plt.title('Projected Scrambling Percentage')
    plt.xlabel('Tournaments')
    plt.ylabel('Scrambling Percentage')
    plt.show()


if __name__ == '__main__':
    with open(STATS_PATH, encoding='utf-8') as json_file:
        players_data = json.load(json_file)

    for obj in players_data:
        # Drop falsy entries (None, '' -- NOTE: this also drops a genuine 0).
        scores = [i for i in (obj.get('Scrambling_List') or []) if i]
        obj['Scrambling_List'] = scores

        data = np.array(create_2d_lst(scores), dtype=float)
        X = data[:, 0].reshape(-1, 1)
        y = data[:, 1].reshape(-1, 1)

        poly_reg = PolynomialFeatures(degree=POLY_DEGREE)
        pol_reg = LinearRegression()
        pol_reg.fit(poly_reg.fit_transform(X), y)

        # X holds indices 0 .. len(X) - 1, so the next tournament is at
        # len(X); len(X) + 1 would skip a tournament.
        next_x = len(X)
        # Compute the forecast once and reuse it for both plot and print.
        prediction = pol_reg.predict(poly_reg.transform([[next_x]]))

        viz_polymonial(X, y, poly_reg, pol_reg, next_x, prediction)
        print(obj['Name'], prediction)

Related Posts

如何对SVC进行超参数调优?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

如何在初始训练后向模型添加训练数据?

我想在我的scikit-learn模型已经训练完成后再…

使用Google Cloud Function并行运行带有不同用户参数的相同训练作业

我正在寻找一种方法来并行运行带有不同用户参数的相同训练…

加载Keras模型,TypeError: ‘module’ object is not callable

我已经在StackOverflow上搜索并阅读了文档,…

在计算KNN填补方法中特定列中NaN值的“距离平均值”时

当我从头开始实现KNN填补方法来处理缺失数据时,我遇到…

使用巨大的S3 CSV文件或直接从预处理的关系型或NoSQL数据库获取数据的机器学习训练/测试工作

已关闭。此问题需要更多细节或更清晰的说明。目前不接受回…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注