Problem with a linear regression implementation

I have just started learning machine learning and am trying to implement vectorized linear regression from scratch with numpy. I tested the implementation on y = x, but the loss keeps increasing and I cannot figure out why. It would be very helpful if someone could point out what is going on here. Thanks in advance!

import numpy as np

class LinearRegressor(object):
    def __init__(self, num_features):
        self.num_features = num_features
        # Random weight initialization; bias starts at zero
        self.w = np.random.randn(num_features, 1).astype(np.float32)
        self.b = np.array(0.0).astype(np.float32)

    def forward(self, x):
        return np.dot(x, self.w) + self.b

    @staticmethod
    def loss(y_pred, y_true):
        # Half mean squared error
        l = np.average(np.power(y_pred - y_true, 2)) / 2
        return l

    def calculate_gradients(self, x, y_pred, y_true):
        self.dl_dw = np.dot(x.T, y_pred - y_true) / len(x)
        self.dl_db = np.mean(y_pred - y_true)

    def optimize(self, step_size):
        self.w -= step_size * self.dl_dw
        self.b -= step_size * self.dl_db

    def train(self, x, y, step_size=1.0):
        y_pred = self.forward(x)
        l = self.loss(y_pred=y_pred, y_true=y)
        self.calculate_gradients(x=x, y_pred=y_pred, y_true=y)
        self.optimize(step_size=step_size)
        return l

    def evaluate(self, x, y):
        return self.loss(self.forward(x), y)  # was y_true, an undefined name

check_reg = LinearRegressor(num_features=1)
x = np.array(list(range(1000))).reshape(-1, 1)
y = x
losses = []
for iteration in range(100):
    loss = check_reg.train(x=x, y=y, step_size=0.001)
    losses.append(loss)
    if iteration % 1 == 0:
        print("Iteration: {}".format(iteration))
        print(loss)

Output

Iteration: 0
612601.7859402705
Iteration: 1
67456013215.98818
Iteration: 2
7427849474110884.0
Iteration: 3
8.179099502901393e+20
Iteration: 4
9.006330707513148e+25
Iteration: 5
9.917228672922966e+30
Iteration: 6
1.0920254505132042e+36
Iteration: 7
1.2024725981084638e+41
Iteration: 8
1.324090295064888e+46
Iteration: 9
1.4580083421516024e+51
Iteration: 10
1.60547085025467e+56
Iteration: 11
1.7678478362285333e+61
Iteration: 12
1.946647415292399e+66
Iteration: 13
2.1435307416407376e+71
Iteration: 14
2.3603265498975516e+76
Iteration: 15
2.599049318486855e+81
Iteration: 16
nan
...
Iteration: 99
nan

Answer:

There is nothing wrong with your implementation. Your step size is simply too large for this data to converge: every update overshoots the minimum, so the optimizer bounces from one side of the loss surface to the other with ever larger error until it overflows to nan. Change the step size as follows:

loss = check_reg.train(x=x, y=y, step_size=0.000001)

and you will get the following result:

Iteration: 0
58305.102166924036
Iteration: 1
25952.192344178206
Iteration: 2
11551.585414406314
Iteration: 3
5141.729521746186
Iteration: 4
2288.6353484460747
Iteration: 5
1018.6952280352172
Iteration: 6
453.4320214875039
Iteration: 7
201.82728832044089
Iteration: 8
89.83519431606754
Iteration: 9
39.98665864625944
Iteration: 10
17.798416262435936
Iteration: 11
7.92229454258205
Iteration: 12
3.526272890501929
Iteration: 13
1.5696002444816197
Iteration: 14
0.6986516574778796
Iteration: 15
0.3109875219688626
Iteration: 16
0.13843156434074647
Iteration: 17
0.061616235257299326
Iteration: 18
0.027424318402401473
Iteration: 19
0.012205888201891543
Iteration: 20
0.005434012356344396
Iteration: 21
0.0024188644277583476
Iteration: 22
0.0010770380211645404
Iteration: 23
0.0004796730257022216
Iteration: 24
0.00021339295719587025
Iteration: 25
9.499628306355218e-05
Iteration: 26
4.244764386691682e-05
Iteration: 27
1.8965112443214162e-05
Iteration: 28
8.56069334821767e-06
Iteration: 29
3.848135476439999e-06
Iteration: 30
1.7367004907528985e-06
Iteration: 31
8.07976330965736e-07
Iteration: 32
4.0167090640020525e-07
Iteration: 33
2.253979336583221e-07
Iteration: 34
1.5365746125585947e-07
Iteration: 35
1.2480275459766612e-07
Iteration: 36
1.1147859663321005e-07
Iteration: 37
1.0288427880059631e-07
Iteration: 38
1.0036079530613815e-07
Iteration: 39
9.901975516098116e-08
Iteration: 40
9.901971962009025e-08
...
Iteration: 99
9.901762276149956e-08
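To see why 0.001 blows up while 0.000001 converges, you can bound the largest stable step size. For a quadratic loss such as your half-MSE, plain gradient descent is stable only when the step size is below 2/λ_max, where λ_max is the largest eigenvalue of the Hessian (a standard result for gradient descent on quadratics). A minimal sketch that evaluates this bound for your data, assuming the single-feature model above:

import numpy as np

x = np.arange(1000, dtype=np.float64).reshape(-1, 1)

# Hessian of the half-MSE loss with respect to (w, b) for one feature:
# [[mean(x^2), mean(x)],
#  [mean(x),   1      ]]
H = np.array([[np.mean(x ** 2), np.mean(x)],
              [np.mean(x),      1.0]])
lam_max = np.linalg.eigvalsh(H).max()
print("largest stable step size:", 2.0 / lam_max)  # roughly 6e-06

With x = 0, ..., 999, mean(x^2) ≈ 332833.5 dominates λ_max, so the bound is about 6e-06: 0.001 sits far above it, so every update overshoots by a growing amount, while 0.000001 sits just below it and converges, if slowly.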

Hope this helps!
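P.S. An alternative worth knowing (not part of your code, so treat this as a sketch): instead of shrinking the step size, you can rescale the inputs so that the curvature stays small and a larger step size becomes stable. Assuming the same LinearRegressor class and x as above:

# Hypothetical variant: scale x into [0, 1]; mean(x^2) is then O(1),
# so even the default step_size=1.0 is below the stability bound
x_scaled = x / x.max()
scaled_reg = LinearRegressor(num_features=1)
for _ in range(100):
    loss = scaled_reg.train(x=x_scaled, y=x_scaled, step_size=1.0)
print(loss)  # should end up close to 0

This is one reason feature scaling is usually recommended before gradient-based training: it lets a single global step size work across all features.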
