I just started learning machine learning and tried to implement vectorized linear regression from scratch with numpy. To test the implementation I used y = x, but the loss keeps increasing instead of decreasing, and I cannot figure out why. It would be very helpful if someone could point out what is going on. Thanks in advance!
import numpy as np

class LinearRegressor(object):
    def __init__(self, num_features):
        self.num_features = num_features
        # Random initial weight, zero initial bias.
        self.w = np.random.randn(num_features, 1).astype(np.float32)
        self.b = np.array(0.0).astype(np.float32)

    def forward(self, x):
        # Linear model: y_pred = x @ w + b
        return np.dot(x, self.w) + self.b

    @staticmethod
    def loss(y_pred, y_true):
        # Mean squared error divided by 2.
        l = np.average(np.power(y_pred - y_true, 2)) / 2
        return l

    def calculate_gradients(self, x, y_pred, y_true):
        # Gradients of the loss with respect to w and b.
        self.dl_dw = np.dot(x.T, y_pred - y_true) / len(x)
        self.dl_db = np.mean(y_pred - y_true)

    def optimize(self, step_size):
        # One gradient-descent update.
        self.w -= step_size * self.dl_dw
        self.b -= step_size * self.dl_db

    def train(self, x, y, step_size=1.0):
        y_pred = self.forward(x)
        l = self.loss(y_pred=y_pred, y_true=y)
        self.calculate_gradients(x=x, y_pred=y_pred, y_true=y)
        self.optimize(step_size=step_size)
        return l

    def evaluate(self, x, y):
        # Loss on held-out data (no parameter update).
        return self.loss(self.forward(x), y)

check_reg = LinearRegressor(num_features=1)
x = np.array(list(range(1000))).reshape(-1, 1)
y = x

losses = []
for iteration in range(100):
    loss = check_reg.train(x=x, y=y, step_size=0.001)
    losses.append(loss)
    if iteration % 1 == 0:
        print("Iteration: {}".format(iteration))
        print(loss)
Output:
Iteration: 0
612601.7859402705
Iteration: 1
67456013215.98818
Iteration: 2
7427849474110884.0
Iteration: 3
8.179099502901393e+20
Iteration: 4
9.006330707513148e+25
Iteration: 5
9.917228672922966e+30
Iteration: 6
1.0920254505132042e+36
Iteration: 7
1.2024725981084638e+41
Iteration: 8
1.324090295064888e+46
Iteration: 9
1.4580083421516024e+51
Iteration: 10
1.60547085025467e+56
Iteration: 11
1.7678478362285333e+61
Iteration: 12
1.946647415292399e+66
Iteration: 13
2.1435307416407376e+71
Iteration: 14
2.3603265498975516e+76
Iteration: 15
2.599049318486855e+81
Iteration: 16
nan
...
Iteration: 99
nan
Answer:
There is nothing wrong with your implementation. Your step size is simply too large for it to converge: during optimization you keep jumping across the minimum, and each overshoot lands on a steeper part of the loss surface, so the error grows without bound and eventually overflows to nan. Change the step size to the following:
loss = check_reg.train(x=x, y=y, step_size=0.000001)
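Why does the step size need to be this small? Here is a quick back-of-the-envelope check (my own sketch, not something from the question): for a quadratic loss, plain gradient descent is only stable while the step size stays below 2 divided by the curvature of the loss, and with x running from 0 to 999 that curvature is dominated by mean(x**2):

import numpy as np

x = np.arange(1000, dtype=np.float64).reshape(-1, 1)

# For the loss l = mean((w*x + b - y)**2) / 2, the second derivative
# with respect to w is mean(x**2). Gradient descent on a quadratic
# diverges once the step size exceeds 2 / curvature.
curvature = np.mean(x ** 2)                     # 332833.5 for x = 0..999
print("max stable step size:", 2 / curvature)   # ~6.0e-06

The original step size of 0.001 is more than a hundred times this threshold, which is exactly why every update overshoots harder than the last; 0.000001 sits safely below it.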
With this step size you will get the following results:
Iteration: 0
58305.102166924036
Iteration: 1
25952.192344178206
Iteration: 2
11551.585414406314
Iteration: 3
5141.729521746186
Iteration: 4
2288.6353484460747
Iteration: 5
1018.6952280352172
Iteration: 6
453.4320214875039
Iteration: 7
201.82728832044089
Iteration: 8
89.83519431606754
Iteration: 9
39.98665864625944
Iteration: 10
17.798416262435936
Iteration: 11
7.92229454258205
Iteration: 12
3.526272890501929
Iteration: 13
1.5696002444816197
Iteration: 14
0.6986516574778796
Iteration: 15
0.3109875219688626
Iteration: 16
0.13843156434074647
Iteration: 17
0.061616235257299326
Iteration: 18
0.027424318402401473
Iteration: 19
0.012205888201891543
Iteration: 20
0.005434012356344396
Iteration: 21
0.0024188644277583476
Iteration: 22
0.0010770380211645404
Iteration: 23
0.0004796730257022216
Iteration: 24
0.00021339295719587025
Iteration: 25
9.499628306355218e-05
Iteration: 26
4.244764386691682e-05
Iteration: 27
1.8965112443214162e-05
Iteration: 28
8.56069334821767e-06
Iteration: 29
3.848135476439999e-06
Iteration: 30
1.7367004907528985e-06
Iteration: 31
8.07976330965736e-07
Iteration: 32
4.0167090640020525e-07
Iteration: 33
2.253979336583221e-07
Iteration: 34
1.5365746125585947e-07
Iteration: 35
1.2480275459766612e-07
Iteration: 36
1.1147859663321005e-07
Iteration: 37
1.0288427880059631e-07
Iteration: 38
1.0036079530613815e-07
Iteration: 39
9.901975516098116e-08
Iteration: 40
9.901971962009025e-08
...
Iteration: 99
9.901762276149956e-08
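As a side note, here is a hypothetical variant (not part of the original fix, and reusing the LinearRegressor class and the x, y arrays from the question): standardizing the inputs before training flattens the curvature of the loss to roughly 1, so an ordinary step size such as 0.1 becomes stable and converges quickly, with no manual tuning of the step size:

# Hypothetical variant: rescale x to zero mean and unit variance so the
# curvature of the loss is ~1 and a step size of 0.1 is stable.
x_scaled = (x - x.mean()) / x.std()
scaled_reg = LinearRegressor(num_features=1)
for iteration in range(100):
    loss = scaled_reg.train(x=x_scaled, y=y, step_size=0.1)
print(loss)  # small and still shrinking after 100 iterations

This is why input normalization is such a common preprocessing step: it makes a single, conventional step size work across features of very different scales.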
Hope this helps!