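I've started taking Andrew Ng's Machine Learning course on Coursera and tried to implement linear regression with gradient descent, but I'm not sure what I'm missing. I tried to implement it based on this, but something seems to be wrong. I should point out that this is my first time using Python and I haven't learned the basics. Here is the code: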
```python
import numpy as np
import matplotlib.pyplot as plt
plt.ion()

x = [1,2,3,4,5]
y = [1,2,3,4,5]

def Gradient_Descent(x, y, learning_rate, iterations):
    theta_1=np.random.randint(low=2, high=5);
    theta_0=np.random.randint(low=2, high=5);
    m = x.shape[0]

def mean_error(a, b, factor):
    sum_mean = 0
    for i in range(m):
        sum_mean += (theta_0 + theta_1 * a[i]) - b[i]  # h(x) = (theta0 + theta1 * x) - y
        if factor:
            sum_mean *= a[i]
    return sum_mean

def perform_cal(theta_0, theta_1, m):
    temp_0 = theta_0 - learning_rate * ((1 / m) * mean_error(x, y, False))
    temp_1 = theta_1 - learning_rate * ((1 / m) * mean_error(x, y, True))
    return temp_0 , temp_1

fig = plt.figure()
ax = fig.add_subplot(111)

for i in range(iterations):
    theta_0, theta_1 = perform_cal(theta_0, theta_1, m)
    ax.clear()
    ax.plot(x, y, linestyle='None', marker='o')
    ax.plot(x, theta_0 + theta_1*x)
    fig.canvas.draw()

x = np.array(x)
y = np.array(y)
Gradient_Descent(x,y, 0.1, 500)

input("Press enter to close program")
```
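For reference, here are the batch gradient-descent updates for the one-variable hypothesis $h_\theta(x) = \theta_0 + \theta_1 x$, as commonly written for this course. Here $\alpha$ corresponds to `learning_rate` and $m$ to the number of samples. Both parameters are updated simultaneously from the old values. Only the $\theta_1$ sum multiplies each individual residual by its own $x^{(i)}$:

```latex
\begin{aligned}
\theta_0 &:= \theta_0 - \alpha \, \frac{1}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) \\
\theta_1 &:= \theta_1 - \alpha \, \frac{1}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x^{(i)}
\end{aligned}
```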
Where am I going wrong?
Answer:
```python
import numpy as np
import matplotlib.pyplot as plt
plt.ion()

x = [1,2,3,4,5]
y = [1,2,3,4,5]

def Gradient_Descent(x, y, learning_rate, iterations):
    theta_1=0
    theta_0=0
    m = x.shape[0]
    for i in range(iterations):
        theta_0, theta_1 = perform_cal(theta_0, theta_1, m, learning_rate)
        ax.clear()
        ax.plot(x, y, linestyle='None', marker='o')
        ax.plot(x, theta_0 + theta_1*x)
        fig.canvas.draw()

def mean_error(a, b, factor, m, theta_0, theta_1):
    sum_mean = 0
    for i in range(m):
        sum_mean += (theta_0 + theta_1 * a[i]) - b[i]  # h(x) = (theta0 + theta1 * x) - y
        if factor:
            sum_mean *= a[i]
    print(sum_mean)
    return sum_mean

def perform_cal(theta_0, theta_1, m, learning_rate):
    temp_0 = theta_0 - learning_rate * ((1 / m) * mean_error(x, y, False, m, theta_0, theta_1))
    temp_1 = theta_1 - learning_rate * ((1 / m) * mean_error(x, y, True, m, theta_0, theta_1))
    return temp_0 , temp_1

fig = plt.figure()
ax = fig.add_subplot(111)
x = np.array(x)
y = np.array(y)
Gradient_Descent(x,y, 0.01, 100)
```
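I made a few changes to your code and now it works. I mostly rearranged a few lines rather than changing what you did, so it wouldn't look confusing. I'd suggest learning the basics of the language first, because most of the errors are very basic ones, such as how parameters are passed. For example, `mean_error` and `perform_cal` used `m`, `theta_0`, `theta_1` and `learning_rate`, which only exist inside `Gradient_Descent`. The training loop also sat at module level, where `iterations` is not defined. Still, it's commendable that you tried to implement the material from Andrew Ng's course on your own.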
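One detail carries over unchanged into the rearranged code. In `mean_error`, `if factor: sum_mean *= a[i]` multiplies the whole running total by `a[i]`, not just the current residual. As a result, the `theta_1` step is not the sum of `(h(x_i) - y_i) * x_i` from the course. With the toy data `x == y`, every residual is zero at `theta_0 = 0, theta_1 = 1`, which can hide this. The sketch below keeps the answer's `mean_error` signature but multiplies each term individually. It adds a plotting-free `gradient_descent` helper to check the result. The helper name, zero initialization, the sample data `[1, 3, 2, 5, 4]`, learning rate 0.05 and 5000 iterations are illustrative assumptions, not part of the answer:

```python
import numpy as np

def mean_error(a, b, factor, m, theta_0, theta_1):
    """Sum of residuals h(x_i) - y_i; if factor is True, each residual is
    multiplied by its own x_i (the sum needed for the theta_1 update)."""
    sum_mean = 0.0
    for i in range(m):
        term = (theta_0 + theta_1 * a[i]) - b[i]  # h(x_i) - y_i
        if factor:
            term *= a[i]  # weight this sample's residual, not the running total
        sum_mean += term
    return sum_mean

def gradient_descent(x, y, learning_rate, iterations):
    """Batch gradient descent for h(x) = theta_0 + theta_1 * x (no plotting).
    Hypothetical helper used only to check the corrected mean_error."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    m = x.shape[0]
    theta_0, theta_1 = 0.0, 0.0
    for _ in range(iterations):
        # Compute both gradients from the old thetas, then update simultaneously.
        grad_0 = mean_error(x, y, False, m, theta_0, theta_1) / m
        grad_1 = mean_error(x, y, True, m, theta_0, theta_1) / m
        theta_0 -= learning_rate * grad_0
        theta_1 -= learning_rate * grad_1
    return theta_0, theta_1

theta_0, theta_1 = gradient_descent([1, 2, 3, 4, 5], [1, 3, 2, 5, 4], 0.05, 5000)
print(theta_0, theta_1)
```

For this non-collinear sample, the result should approach the least-squares line (intercept 0.6, slope 0.8). `np.polyfit(x, y, 1)` also returns this line, as `[slope, intercept]`. That makes it a convenient cross-check that the `theta_1` sum is now correct.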